提交 802f362a 编写于 作者: T tensor-tang

unify the kernelfuncs cache and add unit test

test=develop
上级 36e2d324
......@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track;
int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
platform::CPUPlace>(tag_num);
auto ker = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
platform::CPUPlace>::Cache()
.At(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max();
int max_i = 0;
......
......@@ -110,8 +110,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16;
int C = c / simd_width;
auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
platform::CPUPlace>(0);
auto multiply =
jit::KernelFuncs<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
platform::CPUPlace>::Cache()
.At(0);
#pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
......
......@@ -52,8 +52,10 @@ struct EmbeddingVSumFunctor {
out_width, jit::SeqPoolType::kSum);
for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
platform::CPUPlace>(attr);
auto emb_seqpool =
jit::KernelFuncs<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr);
}
......@@ -135,8 +137,10 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
const T *d_output_data = d_output->data<T>();
auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
platform::CPUPlace>(out_width);
auto vbroadcast =
jit::KernelFuncs<jit::kVBroadcast, jit::VBroadcastTuples<T>,
platform::CPUPlace>::Cache()
.At(out_width);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
const T *src = d_output_data + i * out_width;
......
......@@ -182,29 +182,32 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \
const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const jit::gru_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \
auto ComputeH1 = \
jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart1 = \
jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
auto ComputeHtPart2 = \
jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const jit::gru_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \
auto ComputeH1 = jit::KernelFuncs<jit::kGRUH1, jit::GRUTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr); \
auto ComputeHtPart1 = jit::KernelFuncs<jit::kGRUHtPart1, jit::GRUTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr); \
auto ComputeHtPart2 = jit::KernelFuncs<jit::kGRUHtPart2, jit::GRUTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const {
......
......@@ -235,32 +235,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \
const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const jit::lstm_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
use_peepholes); \
jit::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
auto ComputeC1H1 = \
jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
auto ComputeCtHt = \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const jit::lstm_attr_t attr( \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
use_peepholes); \
jit::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
auto ComputeC1H1 = jit::KernelFuncs<jit::kLSTMC1H1, jit::LSTMTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr); \
auto ComputeCtHt = jit::KernelFuncs<jit::kLSTMCtHt, jit::LSTMTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
......
......@@ -81,10 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() {
template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) {
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
auto addbias_relu =
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(attr.n);
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
auto addbias_relu = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(attr.n);
matmul(x, w, y, &attr);
T* dst = y;
for (int i = 0; i < attr.m; ++i) {
......
......@@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
} else if (pooltype == "SQRT") {
attr.type = jit::SeqPoolType::kSqrt;
}
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
size_t n = ins.size();
size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) {
......
......@@ -93,20 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
attr.n = y_dims[1];
int o_numel = attr.m * attr.n;
auto vsquare_x =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m *
attr.k);
auto vsquare_y =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k *
attr.n);
auto vsquare_xy =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
auto vsub =
jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
auto vscal =
jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
auto vsquare_x = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
platform::CPUPlace>::Cache()
.At(attr.m * attr.k);
auto vsquare_y = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
platform::CPUPlace>::Cache()
.At(attr.k * attr.n);
auto vsquare_xy = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
platform::CPUPlace>::Cache()
.At(o_numel);
auto vsub = jit::KernelFuncs<jit::kVSub, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(o_numel);
auto vscal = jit::KernelFuncs<jit::kVScal, jit::AXYNTuples<T>,
platform::CPUPlace>::Cache()
.At(o_numel);
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
const T* x_data = x->data<T>();
const T* y_data = y->data<T>();
......
......@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
......
......@@ -142,7 +142,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
}
}
// Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
if (!tgt) {
LOG(FATAL) << "Target can not be empty!";
}
......
......@@ -14,6 +14,9 @@
#pragma once
extern "C" {
#include <xxhash.h>
}
#include <iostream>
#include <string>
#include <vector>
......@@ -127,23 +130,36 @@ class KernelFuncs {
return g_func_cache;
}
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int key, typename KernelTuples::func_type func) {
funcs_.emplace(key, func);
}
typename KernelTuples::func_type At(int key) {
// the exposed interface to use
typename KernelTuples::func_type At(
const typename KernelTuples::attr_type& attr) {
// XXH64: 13.8 GB/s
int64_t key = XXH64(&attr, sizeof(typename KernelTuples::attr_type), 0);
if (Has(key)) {
return funcs_.at(key);
}
auto func = Get<KT, KernelTuples, PlaceType>(key);
// If do not have this attr in cache,
// then could run some runtime benchmark of this attr and save the best one.
// Here just get the offline benchmarked best one.
auto func = Get<KT, KernelTuples, PlaceType>(attr);
Insert(key, func);
return func;
}
typename KernelTuples::func_type operator[](
const typename KernelTuples::attr_type& attr) {
return At(attr);
}
protected:
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int64_t key, typename KernelTuples::func_type func) {
funcs_.emplace(key, func);
}
private:
std::unordered_map<int, typename KernelTuples::func_type> funcs_;
std::unordered_map<int64_t, typename KernelTuples::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncs);
};
......
......@@ -462,7 +462,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
}
// test result from Get function
// VLOG(10) << "Test Get function ";
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
test(tgt, args...);
}
......@@ -845,7 +845,9 @@ void TestKernelNCHW16CMulNCTuples() {
T* zjit_data = zjit.data();
constexpr int simd_width = ZMM_FLOAT_BLOCK;
int C = c / simd_width;
auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
auto tgt =
jit::KernelFuncs<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>::Cache().At(
0);
auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
EXPECT_TRUE(tgt != nullptr);
......@@ -967,10 +969,10 @@ void TestKernelVBroadcastTuples() {
}
}
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kernel_type) { \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kernel_type) { \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestKernel##test_tuple<jit::kernel_type, double, CPUPlace>(); \
}
TEST_CPU_KERNEL(XYZNTuples, kVMul);
......@@ -1041,4 +1043,18 @@ TEST(JITKernel_key, gru) {
EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4);
}
// TODO(TJ): add more test about key and pool
TEST(JITKernel, kernel_func) {
auto f1 =
jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
.At(3);
auto f2 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>,
CPUPlace>::Cache()[3];
EXPECT_TRUE(f1 == f2);
f1 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
.At(3);
f2 = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<float>, CPUPlace>::Cache()
.At(4);
EXPECT_TRUE(f1 != f2);
}
......@@ -229,9 +229,9 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(scale->numel(), right);
PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker =
jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
right);
auto ker = jit::KernelFuncs<jit::kLayerNorm, jit::LayerNormTuples<T>,
platform::CPUPlace>::Cache()
.At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon), right);
......
......@@ -255,9 +255,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum);
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr);
......
......@@ -47,8 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
int64_t rows_idx = 0;
T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
......@@ -81,8 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
attr.selected_rows_size = grad_rows.size();
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册