未验证 提交 ccc7c358 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #16104 from tensor-tang/refine/jit

refine jitkernels and test
......@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track;
int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
platform::CPUPlace>(tag_num);
auto ker =
jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
.At(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max();
int max_i = 0;
......
......@@ -110,8 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16;
int C = c / simd_width;
auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
platform::CPUPlace>(0);
auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
platform::CPUPlace>::Cache()
.At(0);
#pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
......
......@@ -52,8 +52,9 @@ struct EmbeddingVSumFunctor {
out_width, jit::SeqPoolType::kSum);
for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
platform::CPUPlace>(attr);
auto emb_seqpool =
jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
.At(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr);
}
......@@ -135,8 +136,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
const T *d_output_data = d_output->data<T>();
auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
platform::CPUPlace>(out_width);
auto vbroadcast =
jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
.At(out_width);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
const T *src = d_output_data + i * out_width;
......
......@@ -196,11 +196,14 @@ class FusionGRUKernel : public framework::OpKernel<T> {
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \
auto ComputeH1 = \
jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At( \
attr); \
auto ComputeHtPart1 = \
jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \
auto ComputeHtPart2 = \
jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
......
......@@ -258,9 +258,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
auto ComputeC1H1 = \
jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
attr); \
auto ComputeCtHt = \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
......
......@@ -82,9 +82,11 @@ template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) {
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
attr);
auto addbias_relu =
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(attr.n);
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
attr.n);
matmul(x, w, y, &attr);
T* dst = y;
for (int i = 0; i < attr.m; ++i) {
......
......@@ -98,7 +98,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
attr.type = jit::SeqPoolType::kSqrt;
}
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
attr);
size_t n = ins.size();
size_t dst_step_size = n * w;
......
......@@ -94,19 +94,23 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
int o_numel = attr.m * attr.n;
auto vsquare_x =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m *
attr.k);
jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
attr.m * attr.k);
auto vsquare_y =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k *
attr.n);
jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
attr.k * attr.n);
auto vsquare_xy =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto vsub =
jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto vscal =
jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
attr);
const T* x_data = x->data<T>();
const T* y_data = y->data<T>();
......
......@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
......
......@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) {
InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \
void BenchJITKernel_##name##_##dtype##_##place##_::Run()
#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU)
void RUN_ALL_BENCHMARK() {
for (auto p : g_all_benchmarks) {
if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) {
......@@ -90,11 +88,11 @@ std::vector<int> TestSizes() {
return s;
}
template <typename KernelTuples, typename... Args>
template <typename KernelTuple, typename... Args>
struct BenchFunc {
// return this function avg time
// TODO(TJ): clear cache every time
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
double operator()(const typename KernelTuple::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
}
......@@ -109,40 +107,17 @@ struct BenchFunc {
namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
BenchFunc<KernelTuples, Args...> benchmark;
template <typename KernelTuple, typename PlaceType, typename... Args>
void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
BenchFunc<KernelTuple, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KT, KernelTuples>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
for (auto f : funcs) {
infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
if (!tgt) {
LOG(FATAL) << "Target can not be empty!";
}
......@@ -150,7 +125,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
// print
std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": "
<< attr << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
......@@ -159,8 +135,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
using Tensor = paddle::framework::Tensor;
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelXYZN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
Tensor x, y, z;
x.Resize({d});
......@@ -171,16 +148,16 @@ void BenchXYZNKernel() {
T* z_data = z.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data);
RandomVec<T>(d, y_data);
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
y.data<T>(), z_data, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y.data<T>(), z_data,
d);
// test inplace
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
z_data, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), z_data, z_data, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelAXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
Tensor x, y;
......@@ -189,26 +166,26 @@ void BenchAXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
d);
BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), y_data, d);
// test inplace
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data,
d);
BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), x_data, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXRNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelXRN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
Tensor x;
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
T res;
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), &res, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
Tensor x, y;
x.Resize({d});
......@@ -216,12 +193,13 @@ void BenchXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data<T>(), y_data, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y_data, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelLSTM() {
using T = typename KernelTuple::data_type;
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
......@@ -252,13 +230,14 @@ void BenchLSTMKernel() {
step.wp = wp_data;
step.checked = checked_data;
}
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelGRU() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
auto place = PlaceType();
......@@ -275,12 +254,13 @@ void BenchGRUKernel() {
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
......@@ -294,15 +274,15 @@ void BenchSeqPoolKernel() {
RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
y_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, x_data, y_data, &attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchEmbSeqPoolKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelEmbSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
int64_t tbl_h = 1e4;
for (int tbl_w : {10, 16, 256}) {
......@@ -324,16 +304,17 @@ void BenchEmbSeqPoolKernel() {
tbl_h - 1);
const int64_t* idx_data = idx.data<int64_t>();
T* o_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>(
attr, table_data, idx_data, o_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, table_data, idx_data,
o_data, &attr);
}
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSgdKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelSgd() {
using T = typename KernelTuple::data_type;
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> {
......@@ -364,15 +345,16 @@ void BenchSgdKernel() {
const T* grad_data = grad.data<T>();
const int64_t* rows_data = rows.data();
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
BenchAllImpls<KT, jit::SgdTuples<T>, PlaceType>(
attr, &lr, param_data, grad_data, rows_data, param_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, &lr, param_data, grad_data,
rows_data, param_data, &attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelMatMul() {
using T = typename KernelTuple::data_type;
for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
for (int k : TestSizes()) {
......@@ -386,15 +368,16 @@ void BenchMatMulKernel() {
const T* b_data = b.data<T>();
T* c_data = c.mutable_data<T>(PlaceType());
const jit::matmul_attr_t attr{m, n, k};
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(attr, a_data, b_data,
c_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, a_data, b_data, c_data,
&attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSoftmaxKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelSoftmax() {
using T = typename KernelTuple::data_type;
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
Tensor x, y;
......@@ -403,14 +386,14 @@ void BenchSoftmaxKernel() {
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
bs);
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchLayerNormKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelLayerNorm() {
using T = typename KernelTuple::data_type;
const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) {
for (int x_dim_0 : {1, 9, 17, 50}) {
......@@ -439,16 +422,17 @@ void BenchLayerNormKernel() {
T* var_data = var.data<T>();
T* out_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::LayerNormTuples<T>, PlaceType>(
right, x_data, out_data, mean_data, var_data, scale_data, bias_data,
left, epsilon, right);
BenchAllImpls<KernelTuple, PlaceType>(right, x_data, out_data,
mean_data, var_data, scale_data,
bias_data, left, epsilon, right);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchCRFDecodingKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelCRFDecoding() {
using T = typename KernelTuple::data_type;
constexpr int state_trans_base_idx = 2;
for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : TestSizes()) {
......@@ -468,14 +452,15 @@ void BenchCRFDecodingKernel() {
T* alpha_data = alpha.mutable_data<T>(PlaceType());
int* track_data = track.mutable_data<int>(PlaceType());
BenchAllImpls<KT, jit::CRFDecodingTuples<T>, PlaceType>(
tag_num, seq_len, x_data, w_data, alpha_data, track_data, tag_num);
BenchAllImpls<KernelTuple, PlaceType>(tag_num, seq_len, x_data, w_data,
alpha_data, track_data, tag_num);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchVBroadcastKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelVBroadcast() {
using T = typename KernelTuple::data_type;
for (int64_t w : {1, 16, 64, 100, 256}) {
Tensor x;
x.Resize({w});
......@@ -485,78 +470,86 @@ void BenchVBroadcastKernel() {
Tensor y;
y.Resize({h * w});
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
w, x_data, y_data, static_cast<int64_t>(h), w);
BenchAllImpls<KernelTuple, PlaceType>(w, x_data, y_data,
static_cast<int64_t>(h), w);
}
}
}
using T = float;
using CPUPlace = paddle::platform::CPUPlace;
#define BenchKernelVMul BenchKernelXYZN
#define BenchKernelVAdd BenchKernelXYZN
#define BenchKernelVAddRelu BenchKernelXYZN
#define BenchKernelVSub BenchKernelXYZN
// xyzn
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
#define BenchKernelVScal BenchKernelAXYN
#define BenchKernelVAddBias BenchKernelAXYN
// axyn
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
#define BenchKernelVRelu BenchKernelXYN
#define BenchKernelVIdentity BenchKernelXYN
#define BenchKernelVSquare BenchKernelXYN
#define BenchKernelVExp BenchKernelXYN
#define BenchKernelVSigmoid BenchKernelXYN
#define BenchKernelVTanh BenchKernelXYN
#define BenchKernelVCopy BenchKernelXYN
// xrn
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
#define BenchKernelHMax BenchKernelXRN
#define BenchKernelHSum BenchKernelXRN
// xyn
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
// lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
// gru functions
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// embedding seq pool function
BENCH_FP32_CPU(kEmbSeqPool) {
BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
}
#define BenchKernelLSTMCtHt BenchKernelLSTM
#define BenchKernelLSTMC1H1 BenchKernelLSTM
// sgd function
BENCH_FP32_CPU(kSgd) { BenchSgdKernel<jit::kSgd, T, CPUPlace>(); }
#define BenchKernelGRUH1 BenchKernelGRU
#define BenchKernelGRUHtPart1 BenchKernelGRU
#define BenchKernelGRUHtPart2 BenchKernelGRU
// matmul
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
using CPUPlace = paddle::platform::CPUPlace;
// softmax
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
#define BENCH_FP32_CPU(name) \
BENCH_JITKERNEL(name, FP32, CPU) { \
BenchKernel##name<jit::name##Tuple<float>, CPUPlace>(); \
}
// layernorm
BENCH_FP32_CPU(kLayerNorm) {
BenchLayerNormKernel<jit::kLayerNorm, T, CPUPlace>();
}
// xyzn
BENCH_FP32_CPU(VMul);
BENCH_FP32_CPU(VAdd);
BENCH_FP32_CPU(VAddRelu);
BENCH_FP32_CPU(VSub);
// crfdecoding
BENCH_FP32_CPU(kCRFDecoding) {
BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
}
// axyn
BENCH_FP32_CPU(VScal);
BENCH_FP32_CPU(VAddBias);
// vbroadcast function
BENCH_FP32_CPU(kVBroadcast) {
BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
}
// xyn
BENCH_FP32_CPU(VRelu);
BENCH_FP32_CPU(VIdentity);
BENCH_FP32_CPU(VSquare);
BENCH_FP32_CPU(VExp);
BENCH_FP32_CPU(VSigmoid);
BENCH_FP32_CPU(VTanh);
BENCH_FP32_CPU(VCopy);
// xrn
BENCH_FP32_CPU(HMax);
BENCH_FP32_CPU(HSum);
// LSTM
BENCH_FP32_CPU(LSTMCtHt);
BENCH_FP32_CPU(LSTMC1H1);
// GRU
BENCH_FP32_CPU(GRUH1);
BENCH_FP32_CPU(GRUHtPart1);
BENCH_FP32_CPU(GRUHtPart2);
BENCH_FP32_CPU(LayerNorm);
BENCH_FP32_CPU(CRFDecoding);
BENCH_FP32_CPU(SeqPool);
BENCH_FP32_CPU(EmbSeqPool);
BENCH_FP32_CPU(MatMul);
BENCH_FP32_CPU(Softmax);
BENCH_FP32_CPU(Sgd);
BENCH_FP32_CPU(VBroadcast);
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -81,7 +82,7 @@ void VActJitCode::genCode() {
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override; \
bool CanBeUsed(const int& attr) const override; \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
......@@ -96,27 +97,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me
bool VReluCreator::UseMe(const int& d) const {
bool VReluCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VSquareCreator::UseMe(const int& d) const {
bool VSquareCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VIdentityCreator::UseMe(const int& d) const {
bool VIdentityCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VExpCreator::UseMe(const int& d) const {
bool VExpCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d < 32;
}
bool VSigmoidCreator::UseMe(const int& d) const {
bool VSigmoidCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
bool VTanhCreator::UseMe(const int& d) const {
bool VTanhCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx);
}
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -142,7 +143,7 @@ void NCHW16CMulNCJitCode::genCode() {
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
bool CanBeUsed(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& d) const override { return 256 * 1024; }
......@@ -154,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx) && attr <= 1024; \
} \
size_t CodeSize(const int& d) const override { \
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
......@@ -121,7 +122,7 @@ void EmbSeqPoolJitCode::genCode() {
class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
public:
bool UseMe(const emb_seq_pool_attr_t& attr) const override {
bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.table_width % YMM_FLOAT_BLOCK == 0;
}
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -86,7 +87,7 @@ void GRUJitCode::genCode() {
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
bool CanBeUsed(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -76,7 +77,7 @@ void HOPVJitCode::genCode() {
#define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
......
......@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
const unsigned char* getCodeInternal() const override {
const Xbyak::uint8* code = CodeGenerator::getCode();
return code;
}
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -114,7 +115,7 @@ void LSTMJitCode::genCode() {
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \
bool CanBeUsed(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
......
......@@ -14,8 +14,8 @@
#include "paddle/fluid/operators/jit/gen/matmul.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
public:
bool UseMe(const matmul_attr_t& attr) const override {
bool CanBeUsed(const matmul_attr_t& attr) const override {
return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
}
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <memory>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -57,7 +58,7 @@ void SeqPoolJitCode::genCode() {
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
public:
bool UseMe(const seq_pool_attr_t& attr) const override {
bool CanBeUsed(const seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx);
}
size_t CodeSize(const seq_pool_attr_t& attr) const override {
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
......@@ -104,7 +105,7 @@ void SgdJitCode::genCode() {
class SgdCreator : public JitCodeCreator<sgd_attr_t> {
public:
bool UseMe(const sgd_attr_t& attr) const override {
bool CanBeUsed(const sgd_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.grad_width % YMM_FLOAT_BLOCK == 0;
}
......
......@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
class VBroadcastCreator : public JitCodeCreator<int64_t> {
public:
bool UseMe(const int64_t& w) const override {
bool CanBeUsed(const int64_t& w) const override {
return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
}
size_t CodeSize(const int64_t& w) const override {
......
......@@ -31,7 +31,7 @@ namespace paddle {
namespace operators {
namespace jit {
// refer do not need useme, it would be the last one.
// refer do not need CanBeUsed, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
static int counter = 0;
......
......@@ -31,9 +31,10 @@ class GenBase : public Kernel {
virtual ~GenBase() = default;
virtual std::string name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
virtual const unsigned char* getCodeInternal() const = 0;
const char* ImplType() const override { return "JitCode"; }
template <typename Func>
Func getCode() {
Func getCode() const {
const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
......@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
virtual bool CanBeUsed(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
......
......@@ -16,6 +16,8 @@
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility> // for std::move
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
......@@ -27,35 +29,34 @@ namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType>
template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
using Attr = typename KernelTuple::attr_type;
int64_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
return codes.AllKernels().at(key).get();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType());
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
if (i && i->CanBeUsed(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
auto res = p.get();
codes.Insert(key, std::move(p));
return f;
return res;
}
}
}
......@@ -63,87 +64,153 @@ GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr;
}
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
template <typename KernelTuple>
inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) {
return i->GetFunc();
return i;
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
typename KernelTuples::func_type Get(
const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
template <typename KernelTuple>
inline typename KernelTuple::func_type GetReferFunc() {
auto ker = GetReferKernel<KernelTuple>();
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
PADDLE_ENFORCE(p, "The Refer kernel should exsit");
return p->GetFunc();
}
// Return all Kernels that can be used
template <typename KernelTuple, typename PlaceType>
std::vector<const Kernel*> GetAllCandidateKernels(
const typename KernelTuple::attr_type& attr) {
// the search order shoudl be jitcode > more > refer
std::vector<const Kernel*> res;
auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
if (jitker) {
res.emplace_back(jitker);
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
// more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = KernelPool::Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
if (i && i->CanBeUsed(attr)) {
res.emplace_back(i);
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KT, KernelTuples>();
auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
res.emplace_back(ref);
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<std::pair<std::string, typename KernelTuple::func_type>>
GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type;
auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
std::vector<std::pair<std::string, Func>> res;
for (auto k : kers) {
std::string name = k->ImplType();
if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE(i, "kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->GetFunc()));
}
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
std::vector<typename KernelTuple::func_type> res;
for (auto& i : funcs) {
res.emplace_back(i.second);
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), 1UL);
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
return funcs[0];
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
template <typename KernelTuple, typename PlaceType>
class KernelFuncs {
public:
KernelFuncs() = default;
static KernelFuncs& Cache() {
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
return g_func_cache;
}
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int key, typename KernelTuples::func_type func) {
funcs_.emplace(key, func);
}
typename KernelTuples::func_type At(int key) {
// the exposed interface to use
typename KernelTuple::func_type At(
const typename KernelTuple::attr_type& attr) {
// Maybe here is not good enough, not all kernels should have jitcode
int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
if (Has(key)) {
return funcs_.at(key);
}
auto func = Get<KT, KernelTuples, PlaceType>(key);
// If do not have this attr in cache then get the default best
auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
Insert(key, func);
return func;
}
typename KernelTuple::func_type operator[](
const typename KernelTuple::attr_type& attr) {
return At(attr);
}
protected:
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int64_t key, typename KernelTuple::func_type func) {
funcs_.emplace(key, func);
}
private:
std::unordered_map<int, typename KernelTuples::func_type> funcs_;
std::unordered_map<int64_t, typename KernelTuple::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncs);
};
......
......@@ -62,26 +62,55 @@ typedef enum {
kSqrt,
} SeqPoolType;
// x, y, z, n
template <typename T>
struct XYZNTuples {
struct XYZNTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
};
// a, x, y, n
template <typename T>
struct AXYNTuples : public XYZNTuples<T> {};
struct AXYNTuple : public XYZNTuple<T> {};
// x, y, n
template <typename T>
struct XYNTuples {
struct XYNTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int);
};
// x, return and int
// x, returned value, n
template <typename T>
struct XRNTuples : public XYNTuples<T> {};
struct XRNTuple : public XYNTuple<T> {};
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
static constexpr KernelType kernel_type = k##type; \
}
// Tuple should be corresponding to the KernelType
DECLARE_KERNELTUPLE(XYZNTuple, VMul);
DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
DECLARE_KERNELTUPLE(XYNTuple, VExp);
DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
DECLARE_KERNELTUPLE(XYNTuple, VTanh);
DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
......@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
template <typename T>
struct LSTMTuples {
struct LSTMTuple {
typedef T data_type;
typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};
template <typename T>
struct GRUTuples {
struct GRUTuple {
typedef T data_type;
typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);
DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);
#undef DECLARE_KERNELTUPLE
template <typename T>
struct VBroadcastTuples {
struct VBroadcastTuple {
static constexpr KernelType kernel_type = kVBroadcast;
typedef T data_type;
typedef int64_t attr_type;
typedef void (*func_type)(const T*, T*, int64_t, int64_t);
......@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s {
} seq_pool_attr_t;
template <typename T>
struct SeqPoolTuples {
struct SeqPoolTuple {
static constexpr KernelType kernel_type = kSeqPool;
typedef T data_type;
typedef seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
......@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s {
} emb_seq_pool_attr_t;
template <typename T>
struct EmbSeqPoolTuples {
struct EmbSeqPoolTuple {
static constexpr KernelType kernel_type = kEmbSeqPool;
typedef T data_type;
typedef emb_seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, const int64_t*, T*,
......@@ -198,7 +239,8 @@ typedef struct sgd_attr_s {
} sgd_attr_t;
template <typename T>
struct SgdTuples {
struct SgdTuple {
static constexpr KernelType kernel_type = kSgd;
typedef T data_type;
typedef sgd_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
......@@ -214,21 +256,24 @@ typedef struct matmul_attr_s {
} matmul_attr_t;
template <typename T>
struct MatMulTuples {
struct MatMulTuple {
static constexpr KernelType kernel_type = kMatMul;
typedef T data_type;
typedef matmul_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
};
template <typename T>
struct CRFDecodingTuples {
struct CRFDecodingTuple {
static constexpr KernelType kernel_type = kCRFDecoding;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};
template <typename T>
struct LayerNormTuples {
struct LayerNormTuple {
static constexpr KernelType kernel_type = kLayerNorm;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
......@@ -236,7 +281,8 @@ struct LayerNormTuples {
};
template <typename T>
struct SoftmaxTuples {
struct SoftmaxTuple {
static constexpr KernelType kernel_type = kSoftmax;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
......@@ -244,7 +290,8 @@ struct SoftmaxTuples {
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
struct NCHW16CMulNCTuple {
static constexpr KernelType kernel_type = kNCHW16CMulNC;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
......@@ -255,28 +302,29 @@ class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
virtual const char* ImplType() const = 0;
DISABLE_COPY_AND_ASSIGN(Kernel);
};
template <typename KernelTuples>
template <typename KernelTuple>
class KernelMore : public Kernel {
public:
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
using T = typename KernelTuple::data_type;
using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuple::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
// specify this kernel can be used, means it should not fail if use it.
virtual bool CanBeUsed(const Attr& attr) const = 0;
protected:
Func func{nullptr};
};
template <typename KernelTuples>
class ReferKernel : public KernelMore<KernelTuples> {
template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuple> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override {
bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h> // XXH64: 13.8 GB/s
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -20,71 +21,46 @@ namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(const int& d) {
int64_t JitCodeKey<int>(const int& d) {
return d;
}
template <>
size_t JitCodeKey<int64_t>(const int64_t& d) {
int64_t JitCodeKey<int64_t>(const int64_t& d) {
return d;
}
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr int act_type_shift = 3; // suppot 2^3 act types
static inline int act_type_convert(KernelType type) {
if (type == kVIdentity) {
return 0;
} else if (type == kVExp) {
return 1;
} else if (type == kVRelu) {
return 2;
} else if (type == kVSigmoid) {
return 3;
} else if (type == kVTanh) {
return 4;
}
PADDLE_THROW("Unsupported act type %d", type);
return 0;
}
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
int64_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
return XXH64(&attr, sizeof(gru_attr_t), 0);
}
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
(act_type_convert(attr.act_cand) << act_type_shift);
int64_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
int keys[5] = {
attr.d, static_cast<int>(attr.act_gate), static_cast<int>(attr.act_cand),
static_cast<int>(attr.act_cell), static_cast<int>(attr.use_peephole)};
return XXH64(keys, sizeof(int) * 5, 0);
}
template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w;
constexpr int pool_type_shift = 3;
return (key << pool_type_shift) + static_cast<int>(attr.type);
int64_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
int keys[2] = {attr.w, static_cast<int>(attr.type)};
return XXH64(keys, sizeof(int) * 2, 0);
}
template <>
size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
size_t key = attr.m;
constexpr int shift = 21;
return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
int64_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
return XXH64(&attr, sizeof(int) * 3, 0); // m, n, k
}
template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
int64_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
return attr.table_width;
}
template <>
size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width;
}
......
......@@ -46,7 +46,7 @@ struct KernelKey {
// Every JitCode should have a method to get the key from attribution
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
int64_t JitCodeKey(const Attr& attr);
} // namespace jit
} // namespace operators
......
......@@ -17,6 +17,7 @@
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <utility> // for move
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
......@@ -30,7 +31,7 @@ namespace jit {
template <KernelType KT>
class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr;
typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap;
typedef std::unordered_map<int64_t, GenBasePtr> JitCodeMap;
public:
JitCodePool() = default;
......@@ -41,9 +42,9 @@ class JitCodePool {
const JitCodeMap& AllKernels() { return codes_; }
bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
void Insert(size_t key, GenBasePtr value) {
void Insert(int64_t key, GenBasePtr value) {
codes_.emplace(key, std::move(value));
}
......
......@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
}
}
bool CRFDecodingKernel::UseMe(const int& d) const {
bool CRFDecodingKernel::CanBeUsed(const int& d) const {
#ifdef __AVX512F__
constexpr int block = ZMM_FLOAT_BLOCK;
#else
......
......@@ -26,11 +26,11 @@ namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num);
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> {
class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(
const typename CRFDecodingTuples<float>::attr_type&) const override;
bool CanBeUsed(
const typename CRFDecodingTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
......
......@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
}
bool LayerNormKernel::UseMe(const int& d) const {
bool LayerNormKernel::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
}
......
......@@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right);
class LayerNormKernel : public KernelMore<LayerNormTuples<float>> {
class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
public:
LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override;
bool CanBeUsed(
const typename LayerNormTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
......
......@@ -23,6 +23,8 @@ namespace jit {
namespace more {
namespace mix {
using CPUPlace = platform::CPUPlace;
void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
......@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
compute(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
......@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
void VTanh(const T* x, T* y, int n) {
const T a = 2, b = -1;
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
compute_scal(&a, x, y, n);
compute_sigmoid(y, y, n);
compute_scal(&a, y, y, n);
......@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
}
void Softmax(const T* x, T* y, int n, int bs) {
auto compute_hmax =
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_hsum =
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vscal =
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vaddbias =
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vexp =
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
for (int i = 0; i < bs; ++i) {
T scalar;
......@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
if (type == kVSigmoid) {
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVRelu) {
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVTanh) {
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVIdentity) {
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
......@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
const int d = attr->d;
const int d2 = d * 2;
const int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_gate_d2 = getActFunc(attr->act_gate, d2);
auto act_gate_d3 = getActFunc(attr->act_gate, d3);
......@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d);
......@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
int d2 = d * 2;
auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d);
vmul_d(gates, gates + d2, ht, d);
......@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc(attr->act_gate, attr->d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
act_gate(gates + attr->d, gates + attr->d, attr->d);
vmul_d(ht_1, gates + attr->d, ht, attr->d);
}
......@@ -206,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
}
// TODO(TJ): tuning me
bool VSigmoidKernel::UseMe(const int& d) const { return true; }
bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
bool VTanhKernel::UseMe(const int& d) const { return true; }
bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
bool SoftmaxKernel::UseMe(const int& d) const { return true; }
bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
} // namespace mix
} // namespace more
......@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
namespace mix = paddle::operators::jit::more::mix;
#define REGISTER_MORE_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MORE_KERNEL(kVTanh, VTanh);
REGISTER_MORE_KERNEL(kSoftmax, Softmax);
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
#define REGISTER_MORE_KERNEL(func) \
REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(VSigmoid);
REGISTER_MORE_KERNEL(VTanh);
REGISTER_MORE_KERNEL(Softmax);
REGISTER_MORE_KERNEL(LSTMCtHt);
REGISTER_MORE_KERNEL(LSTMC1H1);
REGISTER_MORE_KERNEL(GRUH1);
REGISTER_MORE_KERNEL(GRUHtPart1);
REGISTER_MORE_KERNEL(GRUHtPart2);
#undef REGISTER_MORE_KERNEL
......@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name, tuples) \
class name##Kernel : public KernelMore<tuples<T>> { \
#define DECLARE_MORE_KERNEL(name) \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
// XYN
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
DECLARE_MORE_KERNEL(VTanh, XYNTuples);
DECLARE_MORE_KERNEL(VSigmoid);
DECLARE_MORE_KERNEL(VTanh);
// XRN
DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples);
DECLARE_MORE_KERNEL(Softmax);
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
DECLARE_MORE_KERNEL(LSTMCtHt);
DECLARE_MORE_KERNEL(LSTMC1H1);
DECLARE_MORE_KERNEL(GRUH1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples);
DECLARE_MORE_KERNEL(GRUH1);
DECLARE_MORE_KERNEL(GRUHtPart1);
DECLARE_MORE_KERNEL(GRUHtPart2);
#undef DECLARE_MORE_KERNEL
......
......@@ -130,104 +130,105 @@ void ASum<double>(const double* x, double* res, int n) {
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
bool VMulKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VAddKernel<float>::UseMe(const int& d) const {
bool VAddKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d > 512;
}
template <>
bool VScalKernel<float>::UseMe(const int& d) const {
bool VScalKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VExpKernel<float>::UseMe(const int& d) const {
bool VExpKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VSquareKernel<float>::UseMe(const int& d) const {
bool VSquareKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VCopyKernel<float>::UseMe(const int& d) const {
bool VCopyKernel<float>::CanBeUsed(const int& d) const {
return d > 15;
}
template <>
bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
return d > 127;
}
template <>
bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
return true;
}
template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool VTanhKernel<float>::UseMe(const int& d) const {
bool VTanhKernel<float>::CanBeUsed(const int& d) const {
return d > 7;
}
template <>
bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
bool EmbSeqPoolKernel<double>::CanBeUsed(
const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
return true;
}
template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx);
}
template <>
bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const {
bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
return true;
}
template <>
bool SoftmaxKernel<float>::UseMe(const int& d) const {
bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
// tuned on avx2
return platform::MayIUse(platform::avx) && d < 60;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \
bool func##Kernel<double>::CanBeUsed(const int& d) const { \
return true; \
}
......@@ -250,23 +251,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
#define REGISTER_MKL_KERNEL(func) \
REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVCopy, VCopy);
REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax);
REGISTER_MKL_KERNEL(kSgd, Sgd);
REGISTER_MKL_KERNEL(MatMul);
REGISTER_MKL_KERNEL(VMul);
REGISTER_MKL_KERNEL(VAdd);
REGISTER_MKL_KERNEL(VScal);
REGISTER_MKL_KERNEL(VExp);
REGISTER_MKL_KERNEL(VSquare);
REGISTER_MKL_KERNEL(VCopy);
REGISTER_MKL_KERNEL(VBroadcast);
REGISTER_MKL_KERNEL(VSigmoid);
REGISTER_MKL_KERNEL(VTanh);
REGISTER_MKL_KERNEL(SeqPool);
REGISTER_MKL_KERNEL(EmbSeqPool);
REGISTER_MKL_KERNEL(Softmax);
REGISTER_MKL_KERNEL(Sgd);
#undef REGISTER_MKL_KERNEL
......@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
#define DECLARE_MKL_KERNEL(name) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
// ABCMNK
DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
DECLARE_MKL_KERNEL(MatMul);
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
DECLARE_MKL_KERNEL(VMul);
DECLARE_MKL_KERNEL(VAdd);
// AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples);
DECLARE_MKL_KERNEL(VScal);
// XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(VCopy, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
DECLARE_MKL_KERNEL(Sgd, SgdTuples);
DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
DECLARE_MKL_KERNEL(VExp);
DECLARE_MKL_KERNEL(VSigmoid);
DECLARE_MKL_KERNEL(VTanh);
DECLARE_MKL_KERNEL(VSquare);
DECLARE_MKL_KERNEL(VCopy);
// others
DECLARE_MKL_KERNEL(SeqPool);
DECLARE_MKL_KERNEL(EmbSeqPool);
DECLARE_MKL_KERNEL(Softmax);
DECLARE_MKL_KERNEL(Sgd);
DECLARE_MKL_KERNEL(VBroadcast);
#undef DECLARE_MKL_KERNEL
......
......@@ -17,51 +17,43 @@
namespace refer = paddle::operators::jit::refer;
#define REGISTER_REFER_KERNEL(key, func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
#define REGISTER_REFER_KERNEL(func) \
REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel<float>, \
refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(kVMul, VMul);
REGISTER_REFER_KERNEL(kVAdd, VAdd);
REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu);
REGISTER_REFER_KERNEL(kVSub, VSub);
REGISTER_REFER_KERNEL(kVScal, VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVCopy, VCopy);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh);
REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_REFER_KERNEL(kGRUH1, GRUH1);
REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2);
REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding);
REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
REGISTER_REFER_KERNEL(kMatMul, MatMul);
REGISTER_REFER_KERNEL(kHMax, HMax);
REGISTER_REFER_KERNEL(kHSum, HSum);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_REFER_KERNEL(kSgd, Sgd);
REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
REGISTER_REFER_KERNEL(VMul);
REGISTER_REFER_KERNEL(VAdd);
REGISTER_REFER_KERNEL(VAddRelu);
REGISTER_REFER_KERNEL(VSub);
REGISTER_REFER_KERNEL(VScal);
REGISTER_REFER_KERNEL(VAddBias);
REGISTER_REFER_KERNEL(VRelu);
REGISTER_REFER_KERNEL(VCopy);
REGISTER_REFER_KERNEL(VIdentity);
REGISTER_REFER_KERNEL(VSquare);
REGISTER_REFER_KERNEL(VExp);
REGISTER_REFER_KERNEL(VSigmoid);
REGISTER_REFER_KERNEL(VTanh);
REGISTER_REFER_KERNEL(LSTMCtHt);
REGISTER_REFER_KERNEL(LSTMC1H1);
REGISTER_REFER_KERNEL(GRUH1);
REGISTER_REFER_KERNEL(GRUHtPart1);
REGISTER_REFER_KERNEL(GRUHtPart2);
REGISTER_REFER_KERNEL(CRFDecoding);
REGISTER_REFER_KERNEL(LayerNorm);
REGISTER_REFER_KERNEL(NCHW16CMulNC);
REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(VBroadcast);
#undef REGISTER_REFER_KERNEL
......@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
#define DECLARE_REFER_KERNEL(name) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
}
// const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL(VMul, XYZNTuples);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
DECLARE_REFER_KERNEL(VSub, XYZNTuples);
DECLARE_REFER_KERNEL(VMul);
DECLARE_REFER_KERNEL(VAdd);
DECLARE_REFER_KERNEL(VAddRelu);
DECLARE_REFER_KERNEL(VSub);
// const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL(VScal, AXYNTuples);
DECLARE_REFER_KERNEL(VAddBias, AXYNTuples);
DECLARE_REFER_KERNEL(VScal);
DECLARE_REFER_KERNEL(VAddBias);
// const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu, XYNTuples);
DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
DECLARE_REFER_KERNEL(VSquare, XYNTuples);
DECLARE_REFER_KERNEL(VCopy, XYNTuples);
DECLARE_REFER_KERNEL(VRelu);
DECLARE_REFER_KERNEL(VIdentity);
DECLARE_REFER_KERNEL(VExp);
DECLARE_REFER_KERNEL(VSigmoid);
DECLARE_REFER_KERNEL(VTanh);
DECLARE_REFER_KERNEL(VSquare);
DECLARE_REFER_KERNEL(VCopy);
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMCtHt);
DECLARE_REFER_KERNEL(LSTMC1H1);
// gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
DECLARE_REFER_KERNEL(HMax, XRNTuples);
DECLARE_REFER_KERNEL(HSum, XRNTuples);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_REFER_KERNEL(Sgd, SgdTuples);
DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
DECLARE_REFER_KERNEL(GRUH1);
DECLARE_REFER_KERNEL(GRUHtPart1);
DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(HSum);
// others
DECLARE_REFER_KERNEL(CRFDecoding);
DECLARE_REFER_KERNEL(LayerNorm);
DECLARE_REFER_KERNEL(NCHW16CMulNC);
DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(VBroadcast);
#undef DECLARE_REFER_KERNEL
......
......@@ -17,6 +17,7 @@
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility> // for std::move
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
......@@ -49,7 +50,7 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
void operator()(KernelType kt) const {
KernelKey kkey(kt, PlaceType());
Pool().Instance().Insert(kkey,
Pool::Instance().Insert(kkey,
std::move(make_unique<const KERNEL_IMPL_TYPE>()));
constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
......
此差异已折叠。
......@@ -230,8 +230,8 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker =
jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
right);
jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
.At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon), right);
......
......@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return;
}
if (relu) {
auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(N);
auto compute =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
}
} else {
auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(N);
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
......
......@@ -256,8 +256,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum);
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
.At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr);
......
......@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const int kClassDim = 1;
// 2D data. Batch x C
auto compute_softmax =
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
platform::CPUPlace>::Cache()
jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
}
......
......@@ -48,7 +48,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
......@@ -82,7 +83,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册