Commit 14a764c9 authored by T tensor-tang

simplify the jitkernel templates and tests

test=develop
Parent 802f362a
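This commit collapses the three-parameter `jit::KernelFuncs<KernelType, KernelTuples<T>, PlaceType>` template into a two-parameter `jit::KernelFuncs<KernelTuple<T>, PlaceType>` form: each kernel tuple now carries its `KernelType` as a `static constexpr` member, so call sites no longer spell out the enum. A before/after sketch of one call site from this diff (the CRF decoding kernel):

    // before: the KernelType enum is passed alongside the tuple
    auto ker = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
                                platform::CPUPlace>::Cache()
                   .At(tag_num);
    // after: CRFDecodingTuple<T> embeds kernel_type = kCRFDecoding
    auto ker = jit::KernelFuncs<jit::CRFDecodingTuple<T>,
                                platform::CPUPlace>::Cache()
                   .At(tag_num);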
......@@ -82,8 +82,8 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track;
int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::KernelFuncs<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
platform::CPUPlace>::Cache()
auto ker =
jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
.At(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max();
......
......@@ -110,8 +110,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16;
int C = c / simd_width;
auto multiply =
jit::KernelFuncs<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
platform::CPUPlace>::Cache()
.At(0);
#pragma omp parallel for collapse(2)
......
......@@ -53,8 +53,7 @@ struct EmbeddingVSumFunctor {
for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool =
jit::KernelFuncs<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
platform::CPUPlace>::Cache()
jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
.At(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr);
......@@ -138,8 +137,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
const T *d_output_data = d_output->data<T>();
auto vbroadcast =
jit::KernelFuncs<jit::kVBroadcast, jit::VBroadcastTuples<T>,
platform::CPUPlace>::Cache()
jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
.At(out_width);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
......
......@@ -195,14 +195,14 @@ class FusionGRUKernel : public framework::OpKernel<T> {
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \
  auto ComputeH1 = jit::KernelFuncs<jit::kGRUH1, jit::GRUTuples<T>,           \
                                    platform::CPUPlace>::Cache()              \
                       .At(attr);                                             \
  auto ComputeHtPart1 = jit::KernelFuncs<jit::kGRUHtPart1, jit::GRUTuples<T>, \
                                         platform::CPUPlace>::Cache()         \
                            .At(attr);                                        \
  auto ComputeHtPart2 = jit::KernelFuncs<jit::kGRUHtPart2, jit::GRUTuples<T>, \
                                         platform::CPUPlace>::Cache()         \
                            .At(attr);                                        \
  auto ComputeH1 =                                                            \
      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At(   \
          attr);                                                              \
  auto ComputeHtPart1 =                                                       \
      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache()  \
          .At(attr);                                                          \
  auto ComputeHtPart2 =                                                       \
      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache()  \
          .At(attr);                                                          \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
......
......@@ -257,12 +257,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
jit::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
auto ComputeC1H1 = jit::KernelFuncs<jit::kLSTMC1H1, jit::LSTMTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr); \
auto ComputeCtHt = jit::KernelFuncs<jit::kLSTMCtHt, jit::LSTMTuples<T>, \
platform::CPUPlace>::Cache() \
.At(attr)
auto ComputeC1H1 = \
jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
attr); \
auto ComputeCtHt = \
jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
......
......@@ -81,12 +81,12 @@ void FusionRepeatedFCReluOpMaker::Make() {
template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) {
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
auto addbias_relu = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(attr.n);
auto matmul =
jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
attr);
auto addbias_relu =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
attr.n);
matmul(x, w, y, &attr);
T* dst = y;
for (int i = 0; i < attr.m; ++i) {
......
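The fused `fc_relu` above computes y = relu(x * w + b): `matmul` produces the m x n product x * w, then `addbias_relu` (a width-n `VAddRelu` kernel) adds the bias to each of the m output rows and applies ReLU; the body of the row loop is elided in this view.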
......@@ -97,9 +97,9 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
} else if (pooltype == "SQRT") {
attr.type = jit::SeqPoolType::kSqrt;
}
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
auto seqpool =
jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
attr);
size_t n = ins.size();
size_t dst_step_size = n * w;
for (size_t i = 0; i < n; ++i) {
......
......@@ -93,24 +93,24 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
attr.n = y_dims[1];
int o_numel = attr.m * attr.n;
auto vsquare_x = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
platform::CPUPlace>::Cache()
.At(attr.m * attr.k);
auto vsquare_y = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
platform::CPUPlace>::Cache()
.At(attr.k * attr.n);
auto vsquare_xy = jit::KernelFuncs<jit::kVSquare, jit::XYNTuples<T>,
platform::CPUPlace>::Cache()
.At(o_numel);
auto vsub = jit::KernelFuncs<jit::kVSub, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(o_numel);
auto vscal = jit::KernelFuncs<jit::kVScal, jit::AXYNTuples<T>,
platform::CPUPlace>::Cache()
.At(o_numel);
auto matmul = jit::KernelFuncs<jit::kMatMul, jit::MatMulTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
auto vsquare_x =
jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
attr.m * attr.k);
auto vsquare_y =
jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
attr.k * attr.n);
auto vsquare_xy =
jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto vsub =
jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto vscal =
jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto matmul =
jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
attr);
const T* x_data = x->data<T>();
const T* y_data = y->data<T>();
......
......@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) {
InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \
void BenchJITKernel_##name##_##dtype##_##place##_::Run()
#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU)
void RUN_ALL_BENCHMARK() {
for (auto p : g_all_benchmarks) {
if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) {
......@@ -90,11 +88,11 @@ std::vector<int> TestSizes() {
return s;
}
template <typename KernelTuples, typename... Args>
template <typename KernelTuple, typename... Args>
struct BenchFunc {
// returns the average running time of this function
// TODO(TJ): clear cache every time
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
double operator()(const typename KernelTuple::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
}
......@@ -109,31 +107,30 @@ struct BenchFunc {
namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
BenchFunc<KernelTuples, Args...> benchmark;
template <typename KernelTuple, typename PlaceType, typename... Args>
void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
BenchFunc<KernelTuple, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KT, KernelTuples>();
auto refer = jit::GetRefer<KernelTuple>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
auto jitcode = jit::GetJitCode<KernelTuple, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
jit::KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
auto i = dynamic_cast<const jit::KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
......@@ -142,7 +139,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
}
}
// Test result from Get function
auto tgt = jit::KernelFuncs<KT, KernelTuples, PlaceType>::Cache().At(attr);
auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
if (!tgt) {
LOG(FATAL) << "Target can not be empty!";
}
......@@ -150,7 +147,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
// print
std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": "
<< attr << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
......@@ -159,8 +157,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
using Tensor = paddle::framework::Tensor;
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelXYZN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
Tensor x, y, z;
x.Resize({d});
......@@ -171,16 +170,16 @@ void BenchXYZNKernel() {
T* z_data = z.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data);
RandomVec<T>(d, y_data);
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
y.data<T>(), z_data, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y.data<T>(), z_data,
d);
// test inplace
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
z_data, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), z_data, z_data, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelAXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
Tensor x, y;
......@@ -189,26 +188,26 @@ void BenchAXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
d);
BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), y_data, d);
// test inplace
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data,
d);
BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), x_data, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXRNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelXRN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
Tensor x;
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
T res;
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), &res, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
Tensor x, y;
x.Resize({d});
......@@ -216,12 +215,13 @@ void BenchXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data<T>(), y_data, d);
BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y_data, d);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelLSTM() {
using T = typename KernelTuple::data_type;
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
......@@ -252,13 +252,14 @@ void BenchLSTMKernel() {
step.wp = wp_data;
step.checked = checked_data;
}
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelGRU() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
auto place = PlaceType();
......@@ -275,12 +276,13 @@ void BenchGRUKernel() {
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
......@@ -294,15 +296,15 @@ void BenchSeqPoolKernel() {
RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
y_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, x_data, y_data, &attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchEmbSeqPoolKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelEmbSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
int64_t tbl_h = 1e4;
for (int tbl_w : {10, 16, 256}) {
......@@ -324,16 +326,17 @@ void BenchEmbSeqPoolKernel() {
tbl_h - 1);
const int64_t* idx_data = idx.data<int64_t>();
T* o_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>(
attr, table_data, idx_data, o_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, table_data, idx_data,
o_data, &attr);
}
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSgdKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelSgd() {
using T = typename KernelTuple::data_type;
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> {
......@@ -364,15 +367,16 @@ void BenchSgdKernel() {
const T* grad_data = grad.data<T>();
const int64_t* rows_data = rows.data();
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
BenchAllImpls<KT, jit::SgdTuples<T>, PlaceType>(
attr, &lr, param_data, grad_data, rows_data, param_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, &lr, param_data, grad_data,
rows_data, param_data, &attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelMatMul() {
using T = typename KernelTuple::data_type;
for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) {
for (int k : TestSizes()) {
......@@ -386,15 +390,16 @@ void BenchMatMulKernel() {
const T* b_data = b.data<T>();
T* c_data = c.mutable_data<T>(PlaceType());
const jit::matmul_attr_t attr{m, n, k};
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(attr, a_data, b_data,
c_data, &attr);
BenchAllImpls<KernelTuple, PlaceType>(attr, a_data, b_data, c_data,
&attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSoftmaxKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelSoftmax() {
using T = typename KernelTuple::data_type;
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
Tensor x, y;
......@@ -403,14 +408,14 @@ void BenchSoftmaxKernel() {
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
bs);
BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchLayerNormKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelLayerNorm() {
using T = typename KernelTuple::data_type;
const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) {
for (int x_dim_0 : {1, 9, 17, 50}) {
......@@ -439,16 +444,17 @@ void BenchLayerNormKernel() {
T* var_data = var.data<T>();
T* out_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::LayerNormTuples<T>, PlaceType>(
right, x_data, out_data, mean_data, var_data, scale_data, bias_data,
left, epsilon, right);
BenchAllImpls<KernelTuple, PlaceType>(right, x_data, out_data,
mean_data, var_data, scale_data,
bias_data, left, epsilon, right);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchCRFDecodingKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelCRFDecoding() {
using T = typename KernelTuple::data_type;
constexpr int state_trans_base_idx = 2;
for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : TestSizes()) {
......@@ -468,14 +474,15 @@ void BenchCRFDecodingKernel() {
T* alpha_data = alpha.mutable_data<T>(PlaceType());
int* track_data = track.mutable_data<int>(PlaceType());
BenchAllImpls<KT, jit::CRFDecodingTuples<T>, PlaceType>(
tag_num, seq_len, x_data, w_data, alpha_data, track_data, tag_num);
BenchAllImpls<KernelTuple, PlaceType>(tag_num, seq_len, x_data, w_data,
alpha_data, track_data, tag_num);
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchVBroadcastKernel() {
template <typename KernelTuple, typename PlaceType>
void BenchKernelVBroadcast() {
using T = typename KernelTuple::data_type;
for (int64_t w : {1, 16, 64, 100, 256}) {
Tensor x;
x.Resize({w});
......@@ -485,78 +492,86 @@ void BenchVBroadcastKernel() {
Tensor y;
y.Resize({h * w});
T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
w, x_data, y_data, static_cast<int64_t>(h), w);
BenchAllImpls<KernelTuple, PlaceType>(w, x_data, y_data,
static_cast<int64_t>(h), w);
}
}
}
using T = float;
using CPUPlace = paddle::platform::CPUPlace;
#define BenchKernelVMul BenchKernelXYZN
#define BenchKernelVAdd BenchKernelXYZN
#define BenchKernelVAddRelu BenchKernelXYZN
#define BenchKernelVSub BenchKernelXYZN
// xyzn
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
#define BenchKernelVScal BenchKernelAXYN
#define BenchKernelVAddBias BenchKernelAXYN
// axyn
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
#define BenchKernelVRelu BenchKernelXYN
#define BenchKernelVIdentity BenchKernelXYN
#define BenchKernelVSquare BenchKernelXYN
#define BenchKernelVExp BenchKernelXYN
#define BenchKernelVSigmoid BenchKernelXYN
#define BenchKernelVTanh BenchKernelXYN
#define BenchKernelVCopy BenchKernelXYN
// xrn
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
#define BenchKernelHMax BenchKernelXRN
#define BenchKernelHSum BenchKernelXRN
// xyn
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
// lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
// gru functions
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// embedding seq pool function
BENCH_FP32_CPU(kEmbSeqPool) {
BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
}
#define BenchKernelLSTMCtHt BenchKernelLSTM
#define BenchKernelLSTMC1H1 BenchKernelLSTM
// sgd function
BENCH_FP32_CPU(kSgd) { BenchSgdKernel<jit::kSgd, T, CPUPlace>(); }
#define BenchKernelGRUH1 BenchKernelGRU
#define BenchKernelGRUHtPart1 BenchKernelGRU
#define BenchKernelGRUHtPart2 BenchKernelGRU
// matmul
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
using CPUPlace = paddle::platform::CPUPlace;
// softmax
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
#define BENCH_FP32_CPU(name) \
BENCH_JITKERNEL(name, FP32, CPU) { \
BenchKernel##name<jit::name##Tuple<float>, CPUPlace>(); \
}
// layernorm
BENCH_FP32_CPU(kLayerNorm) {
BenchLayerNormKernel<jit::kLayerNorm, T, CPUPlace>();
}
// xyzn
BENCH_FP32_CPU(VMul);
BENCH_FP32_CPU(VAdd);
BENCH_FP32_CPU(VAddRelu);
BENCH_FP32_CPU(VSub);
// crfdecoding
BENCH_FP32_CPU(kCRFDecoding) {
BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
}
// axyn
BENCH_FP32_CPU(VScal);
BENCH_FP32_CPU(VAddBias);
// vbroadcast function
BENCH_FP32_CPU(kVBroadcast) {
BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
}
// xyn
BENCH_FP32_CPU(VRelu);
BENCH_FP32_CPU(VIdentity);
BENCH_FP32_CPU(VSquare);
BENCH_FP32_CPU(VExp);
BENCH_FP32_CPU(VSigmoid);
BENCH_FP32_CPU(VTanh);
BENCH_FP32_CPU(VCopy);
// xrn
BENCH_FP32_CPU(HMax);
BENCH_FP32_CPU(HSum);
// LSTM
BENCH_FP32_CPU(LSTMCtHt);
BENCH_FP32_CPU(LSTMC1H1);
// GRU
BENCH_FP32_CPU(GRUH1);
BENCH_FP32_CPU(GRUHtPart1);
BENCH_FP32_CPU(GRUHtPart2);
BENCH_FP32_CPU(LayerNorm);
BENCH_FP32_CPU(CRFDecoding);
BENCH_FP32_CPU(SeqPool);
BENCH_FP32_CPU(EmbSeqPool);
BENCH_FP32_CPU(MatMul);
BENCH_FP32_CPU(Softmax);
BENCH_FP32_CPU(Sgd);
BENCH_FP32_CPU(VBroadcast);
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
......
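For illustration, with the redefined macro a single line such as `BENCH_FP32_CPU(VMul);` now expands (via `BENCH_JITKERNEL` above) to roughly:

    BENCH_JITKERNEL(VMul, FP32, CPU) {
      // BenchKernelVMul is #defined to BenchKernelXYZN earlier in this file
      BenchKernelVMul<jit::VMulTuple<float>, CPUPlace>();
    }

so registering a benchmark for a kernel needs only the matching `BenchKernel<name>` alias plus one `BENCH_FP32_CPU(<name>);` line.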
......@@ -19,6 +19,8 @@ extern "C" {
}
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility> // for std::move
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
......@@ -30,22 +32,22 @@ namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType>
template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
typename KernelTuple::func_type>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuple::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
auto& codes = JitCodePool<KernelTuple::kernel_type>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType());
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
......@@ -66,27 +68,27 @@ GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
typename KernelTuple::func_type>::type
GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr;
}
// The refer code does not depend on attr, which is only needed for the cast.
// Refer is always on CPUPlace.
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
template <typename KernelTuple>
inline typename KernelTuple::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) {
return i->GetFunc();
}
......@@ -94,23 +96,22 @@ inline typename KernelTuples::func_type GetRefer() {
return nullptr;
}
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
typename KernelTuples::func_type Get(
const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type Get(
const typename KernelTuple::attr_type& attr) {
auto jitfunc = GetJitCode<KernelTuple, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType());
KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
......@@ -118,48 +119,50 @@ typename KernelTuples::func_type Get(
}
// The last implementation should be the reference function on CPUPlace.
return GetRefer<KT, KernelTuples>();
return GetRefer<KernelTuple>();
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
template <typename KernelTuple, typename PlaceType>
class KernelFuncs {
public:
KernelFuncs() = default;
static KernelFuncs& Cache() {
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
return g_func_cache;
}
// the exposed interface to use
typename KernelTuples::func_type At(
const typename KernelTuples::attr_type& attr) {
typename KernelTuple::func_type At(
const typename KernelTuple::attr_type& attr) {
// XXH64: 13.8 GB/s
int64_t key = XXH64(&attr, sizeof(typename KernelTuples::attr_type), 0);
// TODO(TJ): change me, maybe not all attr change need one key, should be
// attrkey
int64_t key = XXH64(&attr, sizeof(typename KernelTuple::attr_type), 0);
if (Has(key)) {
return funcs_.at(key);
}
// If this attr is not in the cache, a runtime benchmark of this attr could
// be run and the best implementation saved. Here we just take the best one
// benchmarked offline.
auto func = Get<KT, KernelTuples, PlaceType>(attr);
auto func = Get<KernelTuple, PlaceType>(attr);
Insert(key, func);
return func;
}
typename KernelTuples::func_type operator[](
const typename KernelTuples::attr_type& attr) {
typename KernelTuple::func_type operator[](
const typename KernelTuple::attr_type& attr) {
return At(attr);
}
protected:
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int64_t key, typename KernelTuples::func_type func) {
void Insert(int64_t key, typename KernelTuple::func_type func) {
funcs_.emplace(key, func);
}
private:
std::unordered_map<int64_t, typename KernelTuples::func_type> funcs_;
std::unordered_map<int64_t, typename KernelTuple::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncs);
};
......
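A minimal usage sketch of the cache defined above, assuming the float `VMul` kernel from this patch. Inside `Get`, dispatch prefers generated JIT code, then any registered "more" implementation whose `UseMe(attr)` returns true, and finally falls back to the reference kernel:

    namespace jit = paddle::operators::jit;
    float x[8], y[8], z[8];
    // Cache() is thread-local and keyed by XXH64 of the attr; a second call
    // with the same attr returns the cached function pointer directly.
    auto vmul = jit::KernelFuncs<jit::VMulTuple<float>,
                                 paddle::platform::CPUPlace>::Cache()
                    .At(8);
    vmul(x, y, z, 8);  // z[i] = x[i] * y[i], per XYZNTuple's func_type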
......@@ -62,26 +62,55 @@ typedef enum {
kSqrt,
} SeqPoolType;
// x, y, z, n
template <typename T>
struct XYZNTuples {
struct XYZNTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
};
// a, x, y, n
template <typename T>
struct AXYNTuples : public XYZNTuples<T> {};
struct AXYNTuple : public XYZNTuple<T> {};
// x, y, n
template <typename T>
struct XYNTuples {
struct XYNTuple {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int);
};
// x, return and int
// x, returned value, n
template <typename T>
struct XRNTuples : public XYNTuples<T> {};
struct XRNTuple : public XYNTuple<T> {};
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
static constexpr KernelType kernel_type = k##type; \
}
// Each Tuple should correspond to its KernelType
DECLARE_KERNELTUPLE(XYZNTuple, VMul);
DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
DECLARE_KERNELTUPLE(XYNTuple, VExp);
DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
DECLARE_KERNELTUPLE(XYNTuple, VTanh);
DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
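// For illustration, DECLARE_KERNELTUPLE(XYZNTuple, VMul) above expands to:
//
//   template <typename T>
//   struct VMulTuple : public XYZNTuple<T> {
//     static constexpr KernelType kernel_type = kVMul;
//   };
//
// This embedded kernel_type is what lets KernelFuncs/Get recover the
// KernelType from the tuple alone instead of taking it as a separate
// template parameter.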
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
......@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
template <typename T>
struct LSTMTuples {
struct LSTMTuple {
typedef T data_type;
typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};
template <typename T>
struct GRUTuples {
struct GRUTuple {
typedef T data_type;
typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);
DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);
#undef DECLARE_KERNELTUPLE
template <typename T>
struct VBroadcastTuples {
struct VBroadcastTuple {
static constexpr KernelType kernel_type = kVBroadcast;
typedef T data_type;
typedef int64_t attr_type;
typedef void (*func_type)(const T*, T*, int64_t, int64_t);
......@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s {
} seq_pool_attr_t;
template <typename T>
struct SeqPoolTuples {
struct SeqPoolTuple {
static constexpr KernelType kernel_type = kSeqPool;
typedef T data_type;
typedef seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
......@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s {
} emb_seq_pool_attr_t;
template <typename T>
struct EmbSeqPoolTuples {
struct EmbSeqPoolTuple {
static constexpr KernelType kernel_type = kEmbSeqPool;
typedef T data_type;
typedef emb_seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, const int64_t*, T*,
......@@ -198,7 +239,8 @@ typedef struct sgd_attr_s {
} sgd_attr_t;
template <typename T>
struct SgdTuples {
struct SgdTuple {
static constexpr KernelType kernel_type = kSgd;
typedef T data_type;
typedef sgd_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
......@@ -214,21 +256,24 @@ typedef struct matmul_attr_s {
} matmul_attr_t;
template <typename T>
struct MatMulTuples {
struct MatMulTuple {
static constexpr KernelType kernel_type = kMatMul;
typedef T data_type;
typedef matmul_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
};
template <typename T>
struct CRFDecodingTuples {
struct CRFDecodingTuple {
static constexpr KernelType kernel_type = kCRFDecoding;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};
template <typename T>
struct LayerNormTuples {
struct LayerNormTuple {
static constexpr KernelType kernel_type = kLayerNorm;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
......@@ -236,7 +281,8 @@ struct LayerNormTuples {
};
template <typename T>
struct SoftmaxTuples {
struct SoftmaxTuple {
static constexpr KernelType kernel_type = kSoftmax;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int);
......@@ -244,7 +290,8 @@ struct SoftmaxTuples {
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
struct NCHW16CMulNCTuple {
static constexpr KernelType kernel_type = kNCHW16CMulNC;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
......@@ -258,12 +305,12 @@ class Kernel {
DISABLE_COPY_AND_ASSIGN(Kernel);
};
template <typename KernelTuples>
template <typename KernelTuple>
class KernelMore : public Kernel {
public:
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
using T = typename KernelTuple::data_type;
using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuple::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
......@@ -272,11 +319,11 @@ class KernelMore : public Kernel {
Func func{nullptr};
};
template <typename KernelTuples>
class ReferKernel : public KernelMore<KernelTuples> {
template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuple> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override {
bool UseMe(const typename KernelTuple::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }
......
......@@ -26,11 +26,10 @@ namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num);
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> {
class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(
const typename CRFDecodingTuples<float>::attr_type&) const override;
bool UseMe(const typename CRFDecodingTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
......
......@@ -27,10 +27,10 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right);
class LayerNormKernel : public KernelMore<LayerNormTuples<float>> {
class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
public:
LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override;
bool UseMe(const typename LayerNormTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
......
......@@ -23,6 +23,8 @@ namespace jit {
namespace more {
namespace mix {
using CPUPlace = platform::CPUPlace;
void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
......@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
compute(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
......@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
void VTanh(const T* x, T* y, int n) {
const T a = 2, b = -1;
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
compute_scal(&a, x, y, n);
compute_sigmoid(y, y, n);
compute_scal(&a, y, y, n);
......@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
}
void Softmax(const T* x, T* y, int n, int bs) {
auto compute_hmax =
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_hsum =
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vscal =
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
  auto compute_vaddbias =
      KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
  auto compute_vexp =
      KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
  auto compute_vaddbias =
      KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
  auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
for (int i = 0; i < bs; ++i) {
T scalar;
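    // (rest of the loop body elided in this view; per row it presumably
    // composes the numerically stable softmax from the kernels cached above,
    // roughly y = exp(x - hmax(x)) followed by y = y * (1 / hsum(y)))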
......@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
if (type == kVSigmoid) {
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVRelu) {
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVTanh) {
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVIdentity) {
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
......@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
const int d = attr->d;
const int d2 = d * 2;
const int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_gate_d2 = getActFunc(attr->act_gate, d2);
auto act_gate_d3 = getActFunc(attr->act_gate, d3);
......@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d);
......@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
int d2 = d * 2;
auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d);
vmul_d(gates, gates + d2, ht, d);
......@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc(attr->act_gate, attr->d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
act_gate(gates + attr->d, gates + attr->d, attr->d);
vmul_d(ht_1, gates + attr->d, ht, attr->d);
}
......@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
namespace mix = paddle::operators::jit::more::mix;
#define REGISTER_MORE_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MORE_KERNEL(kVTanh, VTanh);
REGISTER_MORE_KERNEL(kSoftmax, Softmax);
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
#define REGISTER_MORE_KERNEL(func) \
REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(VSigmoid);
REGISTER_MORE_KERNEL(VTanh);
REGISTER_MORE_KERNEL(Softmax);
REGISTER_MORE_KERNEL(LSTMCtHt);
REGISTER_MORE_KERNEL(LSTMC1H1);
REGISTER_MORE_KERNEL(GRUH1);
REGISTER_MORE_KERNEL(GRUHtPart1);
REGISTER_MORE_KERNEL(GRUHtPart2);
#undef REGISTER_MORE_KERNEL
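For illustration, `REGISTER_MORE_KERNEL(VSigmoid)` now expands to `REGISTER_JITKERNEL_MORE(kVSigmoid, mix, mix::VSigmoidKernel)`: the `k##func` token paste derives the KernelType enum value from the kernel name. The MKL and refer registration macros below are simplified the same way.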
......@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name, tuples) \
class name##Kernel : public KernelMore<tuples<T>> { \
#define DECLARE_MORE_KERNEL(name) \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
// XYN
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
DECLARE_MORE_KERNEL(VTanh, XYNTuples);
DECLARE_MORE_KERNEL(VSigmoid);
DECLARE_MORE_KERNEL(VTanh);
// XRN
DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples);
DECLARE_MORE_KERNEL(Softmax);
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
DECLARE_MORE_KERNEL(LSTMCtHt);
DECLARE_MORE_KERNEL(LSTMC1H1);
DECLARE_MORE_KERNEL(GRUH1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples);
DECLARE_MORE_KERNEL(GRUH1);
DECLARE_MORE_KERNEL(GRUHtPart1);
DECLARE_MORE_KERNEL(GRUHtPart2);
#undef DECLARE_MORE_KERNEL
......
......@@ -250,23 +250,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
#define REGISTER_MKL_KERNEL(func) \
REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVCopy, VCopy);
REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax);
REGISTER_MKL_KERNEL(kSgd, Sgd);
REGISTER_MKL_KERNEL(MatMul);
REGISTER_MKL_KERNEL(VMul);
REGISTER_MKL_KERNEL(VAdd);
REGISTER_MKL_KERNEL(VScal);
REGISTER_MKL_KERNEL(VExp);
REGISTER_MKL_KERNEL(VSquare);
REGISTER_MKL_KERNEL(VCopy);
REGISTER_MKL_KERNEL(VBroadcast);
REGISTER_MKL_KERNEL(VSigmoid);
REGISTER_MKL_KERNEL(VTanh);
REGISTER_MKL_KERNEL(SeqPool);
REGISTER_MKL_KERNEL(EmbSeqPool);
REGISTER_MKL_KERNEL(Softmax);
REGISTER_MKL_KERNEL(Sgd);
#undef REGISTER_MKL_KERNEL
......@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
#define DECLARE_MKL_KERNEL(name) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
bool UseMe(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
// ABCMNK
DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
DECLARE_MKL_KERNEL(MatMul);
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
DECLARE_MKL_KERNEL(VMul);
DECLARE_MKL_KERNEL(VAdd);
// AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples);
DECLARE_MKL_KERNEL(VScal);
// XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(VCopy, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
DECLARE_MKL_KERNEL(Sgd, SgdTuples);
DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
DECLARE_MKL_KERNEL(VExp);
DECLARE_MKL_KERNEL(VSigmoid);
DECLARE_MKL_KERNEL(VTanh);
DECLARE_MKL_KERNEL(VSquare);
DECLARE_MKL_KERNEL(VCopy);
// others
DECLARE_MKL_KERNEL(SeqPool);
DECLARE_MKL_KERNEL(EmbSeqPool);
DECLARE_MKL_KERNEL(Softmax);
DECLARE_MKL_KERNEL(Sgd);
DECLARE_MKL_KERNEL(VBroadcast);
#undef DECLARE_MKL_KERNEL
......
......@@ -17,51 +17,43 @@
namespace refer = paddle::operators::jit::refer;
#define REGISTER_REFER_KERNEL(key, func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
#define REGISTER_REFER_KERNEL(func) \
REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel<float>, \
refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(kVMul, VMul);
REGISTER_REFER_KERNEL(kVAdd, VAdd);
REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu);
REGISTER_REFER_KERNEL(kVSub, VSub);
REGISTER_REFER_KERNEL(kVScal, VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVCopy, VCopy);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh);
REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_REFER_KERNEL(kGRUH1, GRUH1);
REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2);
REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding);
REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
REGISTER_REFER_KERNEL(kMatMul, MatMul);
REGISTER_REFER_KERNEL(kHMax, HMax);
REGISTER_REFER_KERNEL(kHSum, HSum);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_REFER_KERNEL(kSgd, Sgd);
REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
REGISTER_REFER_KERNEL(VMul);
REGISTER_REFER_KERNEL(VAdd);
REGISTER_REFER_KERNEL(VAddRelu);
REGISTER_REFER_KERNEL(VSub);
REGISTER_REFER_KERNEL(VScal);
REGISTER_REFER_KERNEL(VAddBias);
REGISTER_REFER_KERNEL(VRelu);
REGISTER_REFER_KERNEL(VCopy);
REGISTER_REFER_KERNEL(VIdentity);
REGISTER_REFER_KERNEL(VSquare);
REGISTER_REFER_KERNEL(VExp);
REGISTER_REFER_KERNEL(VSigmoid);
REGISTER_REFER_KERNEL(VTanh);
REGISTER_REFER_KERNEL(LSTMCtHt);
REGISTER_REFER_KERNEL(LSTMC1H1);
REGISTER_REFER_KERNEL(GRUH1);
REGISTER_REFER_KERNEL(GRUHtPart1);
REGISTER_REFER_KERNEL(GRUHtPart2);
REGISTER_REFER_KERNEL(CRFDecoding);
REGISTER_REFER_KERNEL(LayerNorm);
REGISTER_REFER_KERNEL(NCHW16CMulNC);
REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(VBroadcast);
#undef REGISTER_REFER_KERNEL
......@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
#define DECLARE_REFER_KERNEL(name) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
}
// const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL(VMul, XYZNTuples);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
DECLARE_REFER_KERNEL(VSub, XYZNTuples);
DECLARE_REFER_KERNEL(VMul);
DECLARE_REFER_KERNEL(VAdd);
DECLARE_REFER_KERNEL(VAddRelu);
DECLARE_REFER_KERNEL(VSub);
// const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL(VScal, AXYNTuples);
DECLARE_REFER_KERNEL(VAddBias, AXYNTuples);
DECLARE_REFER_KERNEL(VScal);
DECLARE_REFER_KERNEL(VAddBias);
// const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu, XYNTuples);
DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
DECLARE_REFER_KERNEL(VSquare, XYNTuples);
DECLARE_REFER_KERNEL(VCopy, XYNTuples);
DECLARE_REFER_KERNEL(VRelu);
DECLARE_REFER_KERNEL(VIdentity);
DECLARE_REFER_KERNEL(VExp);
DECLARE_REFER_KERNEL(VSigmoid);
DECLARE_REFER_KERNEL(VTanh);
DECLARE_REFER_KERNEL(VSquare);
DECLARE_REFER_KERNEL(VCopy);
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMCtHt);
DECLARE_REFER_KERNEL(LSTMC1H1);
// gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
DECLARE_REFER_KERNEL(HMax, XRNTuples);
DECLARE_REFER_KERNEL(HSum, XRNTuples);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_REFER_KERNEL(Sgd, SgdTuples);
DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
DECLARE_REFER_KERNEL(GRUH1);
DECLARE_REFER_KERNEL(GRUHtPart1);
DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(HSum);
// others
DECLARE_REFER_KERNEL(CRFDecoding);
DECLARE_REFER_KERNEL(LayerNorm);
DECLARE_REFER_KERNEL(NCHW16CMulNC);
DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(VBroadcast);
#undef DECLARE_REFER_KERNEL
......
(This diff is collapsed.)
......@@ -229,8 +229,8 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(scale->numel(), right);
PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker = jit::KernelFuncs<jit::kLayerNorm, jit::LayerNormTuples<T>,
platform::CPUPlace>::Cache()
auto ker =
jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
.At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left),
......
......@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return;
}
if (relu) {
auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(N);
auto compute =
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
compute(B, dst, dst, N);
}
} else {
auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>,
platform::CPUPlace>::Cache()
.At(N);
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
......
......@@ -255,8 +255,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum);
auto seqpool = jit::KernelFuncs<jit::kSeqPool, jit::SeqPoolTuples<T>,
platform::CPUPlace>::Cache()
auto seqpool =
jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
.At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
......
......@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const int kClassDim = 1;
// 2D data. Batch x C
auto compute_softmax =
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
platform::CPUPlace>::Cache()
jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
}
......
......@@ -47,9 +47,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
int64_t rows_idx = 0;
T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
auto sgd =
jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
......@@ -82,9 +82,9 @@ class SGDOpKernel : public framework::OpKernel<T> {
attr.selected_rows_size = grad_rows.size();
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd = jit::KernelFuncs<jit::kSgd, jit::SgdTuples<T>,
platform::CPUPlace>::Cache()
.At(attr);
auto sgd =
jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
......