Unverified commit 2d36c9a9 authored by gouzil, committed by GitHub

【Hackathon No.70】[PHI decoupling] move jit kernels from fluid to phi (#50911)

* [phi] move jit kernels from fluid to phi

* [phi] fix paddle::phi err

* [phi] fix windows 'posix_memalign': identifier not found

* [phi] fix windows 'posix_memalign_free': identifier not found

* [phi] fix readme directory structure, fc_functor  paddle::platform
Parent 3b7c9ffc
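The two Windows fixes listed in the bullets above refer to build errors around `posix_memalign` / `posix_memalign_free`, which are not provided by MSVC's CRT. The patch itself is not reproduced here; the usual portable workaround looks roughly like the sketch below. `AlignedAlloc` and `AlignedFree` are hypothetical names used only for illustration, not taken from the commit.

```cpp
#include <cstdlib>
#ifdef _WIN32
#include <malloc.h>  // _aligned_malloc / _aligned_free
#endif

// Hypothetical helpers: allocate/free `size` bytes aligned to `alignment`
// (a power of two and a multiple of sizeof(void*)), on both POSIX and Windows.
inline void* AlignedAlloc(std::size_t alignment, std::size_t size) {
#ifdef _WIN32
  return _aligned_malloc(size, alignment);
#else
  void* ptr = nullptr;
  return posix_memalign(&ptr, alignment, size) == 0 ? ptr : nullptr;
#endif
}

inline void AlignedFree(void* ptr) {
#ifdef _WIN32
  _aligned_free(ptr);
#else
  std::free(ptr);
#endif
}
```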
@@ -24,7 +24,6 @@ add_subdirectory(optimizers)
 add_subdirectory(reduce_ops)
 add_subdirectory(sequence_ops)
 add_subdirectory(string)
-add_subdirectory(jit)
 add_subdirectory(prim_ops)
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 namespace paddle {
@@ -137,9 +137,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
     phi::DenseTensor track;
     int* track_value =
         track.mutable_data<int>(emission_dims, platform::CPUPlace());
-    auto ker =
-        jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
-            .At(tag_num);
+    auto ker = phi::jit::KernelFuncs<phi::jit::CRFDecodingTuple<T>,
+                                     platform::CPUPlace>::Cache()
+                   .At(tag_num);
     ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
     T max_score = -std::numeric_limits<T>::max();
     int max_i = 0;
...
@@ -22,8 +22,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows_utils.h"
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace paddle {
 namespace operators {
@@ -108,17 +108,17 @@ struct EmbeddingVSumFunctor {
                           "But received the ids's LoD[0] = %d.",
                           ids_lod.size()));
-    jit::emb_seq_pool_attr_t attr(table_height,
-                                  table_width,
-                                  0,
-                                  idx_width,
-                                  out_width,
-                                  jit::SeqPoolType::kSum);
+    phi::jit::emb_seq_pool_attr_t attr(table_height,
+                                       table_width,
+                                       0,
+                                       idx_width,
+                                       out_width,
+                                       phi::jit::SeqPoolType::kSum);
     for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
       attr.index_height = ids_lod[i + 1] - ids_lod[i];
-      auto emb_seqpool =
-          jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
-              .At(attr);
+      auto emb_seqpool = phi::jit::KernelFuncs<phi::jit::EmbSeqPoolTuple<T>,
+                                               platform::CPUPlace>::Cache()
+                             .At(attr);
       emb_seqpool(
           table, ids + ids_lod[i] * idx_width, output + i * out_width, &attr);
     }
@@ -265,9 +265,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
     T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
     const T *d_output_data = d_output->data<T>();
-    auto vbroadcast =
-        jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
-            .At(out_width);
+    auto vbroadcast = phi::jit::KernelFuncs<phi::jit::VBroadcastTuple<T>,
+                                            platform::CPUPlace>::Cache()
+                          .At(out_width);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
       const T *src = d_output_data + i * out_width;
...
@@ -19,9 +19,9 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/fc_functor.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/sequence2batch.h"
 namespace paddle {
@@ -273,33 +273,33 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   const int total_T = x_mat_dims[0]; \
   const int D3 = wh_dims[1]
 #define INIT_OTHER_DEFINES \
   auto* h0 = ctx.Input<phi::DenseTensor>("H0"); \
   auto* wx = ctx.Input<phi::DenseTensor>("WeightX"); \
   auto* bias = ctx.Input<phi::DenseTensor>("Bias"); \
   auto* hidden_out = ctx.Output<phi::DenseTensor>("Hidden"); \
   bool is_reverse = ctx.Attr<bool>("is_reverse"); \
   const int M = x_mat_dims[1]; \
   const int D = wh_dims[0]; \
   const int D2 = D * 2; \
-  const jit::gru_attr_t attr( \
-      D, \
-      jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
-      jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
-  jit::gru_t one_step; \
-  auto ComputeH1 = \
-      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At( \
-          attr); \
-  auto ComputeHtPart1 = \
-      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
-          .At(attr); \
-  auto ComputeHtPart2 = \
-      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
-          .At(attr); \
+  const phi::jit::gru_attr_t attr( \
+      D, \
+      phi::jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
+      phi::jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
+  phi::jit::gru_t one_step; \
+  auto ComputeH1 = phi::jit::KernelFuncs<phi::jit::GRUH1Tuple<T>, \
+                                         platform::CPUPlace>::Cache() \
+                       .At(attr); \
+  auto ComputeHtPart1 = phi::jit::KernelFuncs<phi::jit::GRUHtPart1Tuple<T>, \
+                                              platform::CPUPlace>::Cache() \
+                            .At(attr); \
+  auto ComputeHtPart2 = phi::jit::KernelFuncs<phi::jit::GRUHtPart2Tuple<T>, \
+                                              platform::CPUPlace>::Cache() \
+                            .At(attr); \
   const T* x_data = x->data<T>(); \
   const T* wx_data = wx->data<T>(); \
   const T* wh_data = wh->data<T>(); \
   auto place = ctx.GetPlace(); \
   T* xx_data = xx->mutable_data<T>(place)
 void SeqCompute(const framework::ExecutionContext& ctx) const {
...
@@ -16,9 +16,9 @@ limitations under the License. */
 #include <string>
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/fc_functor.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/sequence2batch.h"
 namespace paddle {
@@ -320,35 +320,35 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   const int D = wh_dims[0]; \
   const int D4 = wh_dims[1]
 #define INIT_OTHER_DEFINES \
   const T* x_data = x->data<T>(); \
   const T* wx_data = wx->data<T>(); \
   const T* wh_data = wh->data<T>(); \
   /* diagonal weight*/ \
   const T* wp_data = bias->data<T>() + D4; \
   /* for peephole only*/ \
   T* checked_cell_data = nullptr; \
   auto place = ctx.GetPlace(); \
   if (use_peepholes) { \
     /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
     auto* checked_cell = ctx.Output<phi::DenseTensor>("CheckedCell"); \
     checked_cell_data = checked_cell->mutable_data<T>(place); \
   } \
-  const jit::lstm_attr_t attr( \
-      D, \
-      jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
-      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
-      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
-      use_peepholes); \
-  jit::lstm_t one_step; \
+  const phi::jit::lstm_attr_t attr( \
+      D, \
+      phi::jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
+      phi::jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
+      phi::jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
+      use_peepholes); \
+  phi::jit::lstm_t one_step; \
   one_step.wp = wp_data; \
   one_step.checked = checked_cell_data; \
-  auto ComputeC1H1 = \
-      jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
-          attr); \
-  auto ComputeCtHt = \
-      jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
-          attr)
+  auto ComputeC1H1 = phi::jit::KernelFuncs<phi::jit::LSTMC1H1Tuple<T>, \
+                                           platform::CPUPlace>::Cache() \
+                         .At(attr); \
+  auto ComputeCtHt = phi::jit::KernelFuncs<phi::jit::LSTMCtHtTuple<T>, \
+                                           platform::CPUPlace>::Cache() \
+                         .At(attr)
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out) \
...
@@ -17,7 +17,7 @@
 #include <string>
 #include <vector>
-#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace paddle {
 namespace operators {
@@ -122,14 +122,17 @@ void FusionRepeatedFCReluOpMaker::Make() {
 }
 template <typename T>
-static void fc_relu(
-    const T* x, const T* w, const T* b, T* y, const jit::matmul_attr_t& attr) {
-  auto matmul =
-      jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
-          attr);
-  auto addbias_relu =
-      jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
-          attr.n);
+static void fc_relu(const T* x,
+                    const T* w,
+                    const T* b,
+                    T* y,
+                    const phi::jit::matmul_attr_t& attr) {
+  auto matmul = phi::jit::KernelFuncs<phi::jit::MatMulTuple<T>,
+                                      platform::CPUPlace>::Cache()
+                    .At(attr);
+  auto addbias_relu = phi::jit::KernelFuncs<phi::jit::VAddReluTuple<T>,
+                                            platform::CPUPlace>::Cache()
+                          .At(attr.n);
   matmul(x, w, y, &attr);
   T* dst = y;
   for (int i = 0; i < attr.m; ++i) {
@@ -152,7 +155,7 @@ class FusionRepeatedFCReluKernel : public framework::OpKernel<T> {
     auto i_dims = in->dims();
     const auto& w_dims = weights[0]->dims();
-    jit::matmul_attr_t attr;
+    phi::jit::matmul_attr_t attr;
     attr.m = i_dims[0];
     attr.n = w_dims[1];
     attr.k = w_dims[0];
...
@@ -17,7 +17,7 @@
 #include <string>
 #include <vector>
-#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace paddle {
 namespace operators {
@@ -121,15 +121,15 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
                           "dims[1] is %d, w is %d.",
                           y_dims[1],
                           w));
-    jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
+    phi::jit::seq_pool_attr_t attr(w, phi::jit::SeqPoolType::kSum);
     if (pooltype == "AVERAGE") {
-      attr.type = jit::SeqPoolType::kAvg;
+      attr.type = phi::jit::SeqPoolType::kAvg;
     } else if (pooltype == "SQRT") {
-      attr.type = jit::SeqPoolType::kSqrt;
+      attr.type = phi::jit::SeqPoolType::kSqrt;
     }
-    auto seqpool =
-        jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
-            attr);
+    auto seqpool = phi::jit::KernelFuncs<phi::jit::SeqPoolTuple<T>,
+                                         platform::CPUPlace>::Cache()
+                       .At(attr);
     size_t n = ins.size();
     size_t dst_step_size = n * w;
     for (size_t i = 0; i < n; ++i) {
...
@@ -17,7 +17,7 @@
 #include <string>
 #include <vector>
-#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace paddle {
 namespace operators {
@@ -122,15 +122,15 @@ class FusionSeqPoolCVMConcatKernel : public framework::OpKernel<T> {
         0,
         paddle::platform::errors::InvalidArgument(
            "The output of dims[1] should be dividable of w"));
-    jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
+    phi::jit::seq_pool_attr_t attr(w, phi::jit::SeqPoolType::kSum);
     if (pooltype == "AVERAGE") {
-      attr.type = jit::SeqPoolType::kAvg;
+      attr.type = phi::jit::SeqPoolType::kAvg;
     } else if (pooltype == "SQRT") {
-      attr.type = jit::SeqPoolType::kSqrt;
+      attr.type = phi::jit::SeqPoolType::kSqrt;
     }
-    auto seqpool =
-        jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
-            attr);
+    auto seqpool = phi::jit::KernelFuncs<phi::jit::SeqPoolTuple<T>,
+                                         platform::CPUPlace>::Cache()
+                       .At(attr);
     size_t n = ins.size();
     size_t dst_step_size = n * w;
     for (size_t i = 0; i < n; ++i) {
...
@@ -17,7 +17,7 @@
 #include <string>
 #include <vector>
-#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace paddle {
 namespace operators {
@@ -99,30 +99,30 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
     auto x_dims = x->dims();
     auto y_dims = y->dims();
-    jit::matmul_attr_t attr;
+    phi::jit::matmul_attr_t attr;
     attr.m = x_dims[0];
     attr.k = x_dims[1];
     attr.n = y_dims[1];
     int o_numel = attr.m * attr.n;
-    auto vsquare_x =
-        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
-            attr.m * attr.k);
-    auto vsquare_y =
-        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
-            attr.k * attr.n);
-    auto vsquare_xy =
-        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
-            o_numel);
-    auto vsub =
-        jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
-            o_numel);
-    auto vscal =
-        jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
-            o_numel);
-    auto matmul =
-        jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
-            attr);
+    auto vsquare_x = phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>,
+                                           platform::CPUPlace>::Cache()
+                         .At(attr.m * attr.k);
+    auto vsquare_y = phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>,
+                                           platform::CPUPlace>::Cache()
+                         .At(attr.k * attr.n);
+    auto vsquare_xy = phi::jit::KernelFuncs<phi::jit::VSquareTuple<T>,
+                                            platform::CPUPlace>::Cache()
+                          .At(o_numel);
+    auto vsub = phi::jit::KernelFuncs<phi::jit::VSubTuple<T>,
+                                      platform::CPUPlace>::Cache()
+                    .At(o_numel);
+    auto vscal = phi::jit::KernelFuncs<phi::jit::VScalTuple<T>,
+                                       platform::CPUPlace>::Cache()
+                     .At(o_numel);
+    auto matmul = phi::jit::KernelFuncs<phi::jit::MatMulTuple<T>,
+                                        platform::CPUPlace>::Cache()
+                      .At(attr);
     const T* x_data = x->data<T>();
     const T* y_data = y->data<T>();
...
@@ -18,7 +18,6 @@ limitations under the License. */
 #include <string>
 #include <vector>
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/fc_functor.h"
 #include "paddle/phi/kernels/funcs/sequence2batch.h"
...
@@ -16,8 +16,8 @@ limitations under the License. */
 #include <string>
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 namespace paddle {
@@ -382,12 +382,12 @@ class SequencePoolFunctor<phi::CPUContext, T> {
            "Sequence_pool should run on CPU Device when pooltype is SUM"));
     const T* src = input.data<T>();
     T* dst = output->mutable_data<T>(place);
-    jit::seq_pool_attr_t attr(
-        static_cast<int>(input.numel() / input.dims()[0]),
-        jit::SeqPoolType::kSum);
-    auto seqpool =
-        jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
-            .At(attr);
+    phi::jit::seq_pool_attr_t attr(
+        static_cast<int>(input.numel() / input.dims()[0]),
+        phi::jit::SeqPoolType::kSum);
+    auto seqpool = phi::jit::KernelFuncs<phi::jit::SeqPoolTuple<T>,
+                                         platform::CPUPlace>::Cache()
+                       .At(attr);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       attr.h = static_cast<int>(lod[i + 1] - lod[i]);
       if (attr.h == 0) {
...
@@ -18,8 +18,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows_utils.h"
 #include "paddle/fluid/framework/var_type_traits.h"
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace paddle {
 namespace operators {
@@ -43,16 +43,16 @@ struct sgd_dense_param_kernel<T,
     const auto *grad = ctx.Input<phi::DenseTensor>("Grad");
     const auto sz = param_out->numel();
-    jit::sgd_attr_t attr(1, sz, 1, sz, 1);
+    phi::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
     const T *lr = learning_rate->data<T>();
     const T *param_data = param->data<T>();
     const T *grad_data = grad->data<T>();
     int64_t rows_idx = 0;
     T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
-    auto sgd =
-        jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
-            attr);
+    auto sgd = phi::jit::KernelFuncs<phi::jit::SgdTuple<T>,
+                                     platform::CPUPlace>::Cache()
+                   .At(attr);
     sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
   }
 };
@@ -76,16 +76,16 @@ struct sgd_dense_param_kernel<T,
     const int64_t *rows_data = grad_rows.data();
     T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
-    jit::sgd_attr_t attr;
+    phi::jit::sgd_attr_t attr;
     attr.param_height = param_out->dims()[0];
     attr.param_width = param_out->numel() / attr.param_height;
     attr.grad_height = grad_rows.size();  // note: it is not grad->height()
     attr.grad_width = grad_value.numel() / attr.grad_height;
     attr.selected_rows_size = grad_rows.size();
-    auto sgd =
-        jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
-            attr);
+    auto sgd = phi::jit::KernelFuncs<phi::jit::SgdTuple<T>,
+                                     platform::CPUPlace>::Cache()
+                   .At(attr);
     sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
   }
 };
...
@@ -16,11 +16,11 @@
 #include <vector>
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/adam_functors.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 DECLARE_int32(inner_op_parallelism);
@@ -114,7 +114,7 @@ void AdamDenseKernel(const Context& dev_ctx,
       learning_rate.data<T>()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p));
   T eps = epsilon_ * sqrt(1 - beta2_p);
-  paddle::operators::jit::adam_attr_t attr(beta1_, beta2_);
+  phi::jit::adam_attr_t attr(beta1_, beta2_);
   int64_t numel = param.numel();
   const T* param_ptr = param.data<T>();
@@ -123,9 +123,8 @@ void AdamDenseKernel(const Context& dev_ctx,
   const T* grad_ptr = grad.data<T>();
   auto adam =
-      paddle::operators::jit::KernelFuncs<paddle::operators::jit::AdamTuple<T>,
-                                          phi::CPUPlace>::Cache()
-          .At(attr);
+      phi::jit::KernelFuncs<phi::jit::AdamTuple<T>, phi::CPUPlace>::Cache().At(
+          attr);
   static constexpr int64_t chunk_size = 512;
...
@@ -16,13 +16,13 @@
 #include <vector>
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/adam_kernel.h"
 #include "paddle/phi/kernels/funcs/adam_functors.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace phi {
@@ -141,9 +141,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
   const T* grad_ptr = grad.data<T>();
   auto adamw =
-      paddle::operators::jit::KernelFuncs<paddle::operators::jit::AdamWTuple<T>,
-                                          phi::CPUPlace>::Cache()
-          .At(1);
+      phi::jit::KernelFuncs<phi::jit::AdamWTuple<T>, phi::CPUPlace>::Cache().At(
+          1);
   static constexpr int64_t chunk_size = 512;
...
@@ -18,7 +18,7 @@
 #include "paddle/phi/kernels/funcs/layer_norm_util.h"
 #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
     !defined(__OSX__)
-#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 #endif
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
...
@@ -18,7 +18,7 @@
 #include "paddle/phi/kernels/funcs/layer_norm_util.h"
 #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
     !defined(__OSX__)
-#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 #endif
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -123,10 +123,9 @@ void LayerNormKernel(const Context& dev_ctx,
             right));
   }
-  auto ker = paddle::operators::jit::KernelFuncs<
-                 paddle::operators::jit::LayerNormTuple<T>,
-                 phi::CPUPlace>::Cache()
-                 .At(right);
+  auto ker =
+      phi::jit::KernelFuncs<phi::jit::LayerNormTuple<T>, phi::CPUPlace>::Cache()
+          .At(right);
   ker(x_tmp.data<T>(),
       out.data<T>(),
       mean->data<T>(),
...
@@ -14,10 +14,10 @@
 #include "paddle/phi/kernels/sgd_kernel.h"
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace phi {
@@ -27,7 +27,7 @@ void sgd_dense_param_dense_grad_impl(const DenseTensor& param,
                                      const DenseTensor& grad,
                                      DenseTensor* param_out) {
   const auto sz = param_out->numel();
-  paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
+  phi::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
   const T* lr = learning_rate.data<T>();
   const T* param_data = param.data<T>();
   const T* grad_data = grad.data<T>();
@@ -35,9 +35,8 @@ void sgd_dense_param_dense_grad_impl(const DenseTensor& param,
   T* out_data = param_out->data<T>();
   auto sgd =
-      paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
-                                          phi::CPUPlace>::Cache()
-          .At(attr);
+      phi::jit::KernelFuncs<phi::jit::SgdTuple<T>, phi::CPUPlace>::Cache().At(
+          attr);
   sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
 }
@@ -68,7 +67,7 @@ void sgd_dense_param_sparse_grad_impl(const DenseTensor& param,
   const int64_t* rows_data = grad_rows.data();
   T* out_data = param_out->data<T>();
-  paddle::operators::jit::sgd_attr_t attr;
+  phi::jit::sgd_attr_t attr;
   attr.param_height = param_out->dims()[0];
   attr.param_width = param_out->numel() / attr.param_height;
   attr.grad_height = grad_rows.size();  // note: it is not grad->height()
@@ -76,9 +75,8 @@ void sgd_dense_param_sparse_grad_impl(const DenseTensor& param,
   attr.selected_rows_size = grad_rows.size();
   auto sgd =
-      paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
-                                          phi::CPUPlace>::Cache()
-          .At(attr);
+      phi::jit::KernelFuncs<phi::jit::SgdTuple<T>, phi::CPUPlace>::Cache().At(
+          attr);
   sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
 }
...
@@ -2,6 +2,7 @@ add_subdirectory(eigen)
 add_subdirectory(blas)
 add_subdirectory(lapack)
 add_subdirectory(detail)
+add_subdirectory(jit)
 math_library(deformable_conv_functor DEPS dense_tensor)
 math_library(concat_and_split_functor DEPS dense_tensor)
...
@@ -14,9 +14,9 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/fc_functor.h"
-#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 namespace phi {
 namespace funcs {
@@ -81,13 +81,11 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
         errors::PermissionDenied("When bias is NULL, relu can not be true."));
     return;
   }
-  auto compute = relu ? paddle::operators::jit::KernelFuncs<
-                            paddle::operators::jit::VAddReluTuple<T>,
-                            paddle::platform::CPUPlace>::Cache()
-                            .At(N)
-                      : paddle::operators::jit::KernelFuncs<
-                            paddle::operators::jit::VAddTuple<T>,
-                            paddle::platform::CPUPlace>::Cache()
-                            .At(N);
+  auto compute = relu ? phi::jit::KernelFuncs<phi::jit::VAddReluTuple<T>,
+                                              phi::CPUPlace>::Cache()
+                            .At(N)
+                      : phi::jit::KernelFuncs<phi::jit::VAddTuple<T>,
+                                              phi::CPUPlace>::Cache()
+                            .At(N);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
...
-set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h.tmp)
-set(jit_file_final ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
+set(jit_file ${PADDLE_BINARY_DIR}/paddle/phi/kernels/funcs/jit/kernels.h.tmp)
+set(jit_file_final ${PADDLE_BINARY_DIR}/paddle/phi/kernels/funcs/jit/kernels.h)
 file(
   WRITE ${jit_file}
-  "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n"
+  "// Generated by the paddle/phi/kernels/funcs/jit/CMakeLists.txt. DO NOT EDIT!\n\n"
 )
 file(APPEND ${jit_file} "\#pragma once\n")
-file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
+file(APPEND ${jit_file} "\#include \"paddle/phi/kernels/funcs/jit/helper.h\"\n")
 file(APPEND ${jit_file}
-     "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
+     "\#include \"paddle/phi/kernels/funcs/jit/registry.h\"\n\n")
 set(JIT_KERNEL_DEPS device_context cblas gflags enforce place xxhash)
...
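Per the `file(WRITE ...)` / `file(APPEND ...)` commands above, the auto-generated `kernels.h` now begins roughly as follows; the rest of the file, generated from the registered kernels, is not shown in this hunk.

```cpp
// Generated by the paddle/phi/kernels/funcs/jit/CMakeLists.txt. DO NOT EDIT!

#pragma once
#include "paddle/phi/kernels/funcs/jit/helper.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
```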
@@ -10,9 +10,9 @@ Currently it's only supported on CPU yet.
 ## Contents
 ```txt
-PaddlePaddle/Paddle/paddle/fluid/
+PaddlePaddle/Paddle/paddle/phi/kernels/
 ├── ...
-└── operators/
+└── funcs/
     ├── .../
     └── jit/
         ├── ...
@@ -34,7 +34,7 @@ PaddlePaddle/Paddle/paddle/fluid/
         └── ...
 ```
-All basical definitions of jit kernels are addressed in `paddle/fluid/operators/jit` including these three key folders `refer`, `gen`, `more`. There is only one unique name for each kernel while may have seraval implementations with same functionality.
+All basical definitions of jit kernels are addressed in `paddle/phi/kernels/funcs/jit` including these three key folders `refer`, `gen`, `more`. There is only one unique name for each kernel while may have seraval implementations with same functionality.
 - `refer`: Each kernel must have one reference implementation on CPU, and it should only focus on the correctness and should not depends on any third-party libraries.
 - `gen`: The code generated should be kept here. They should be designed focusing on the best performance, which depends on Xbyak.
@@ -55,7 +55,7 @@ Get from cache:
 ```cpp
 using T = float;
 jit::seq_pool_attr_t attr(width, jit::SeqPoolType::kSum);
-auto seqpool_func = jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(attr);
+auto seqpool_func = jit::KernelFuncs<jit::SeqPoolTuple<T>, phi::CPUPlace>::Cache().At(attr);
 seqpool_func(src_data, dst_data, &attr);
 ```
@@ -64,14 +64,14 @@ Get all implementations and run once:
 ```cpp
 using T = float;
 jit::seq_pool_attr_t attr(width, jit::SeqPoolType::kSum);
-auto funcs = jit::GetAllCandidateFuncsWithTypes<jit::SeqPoolTuple<T>, platform::CPUPlace>(attr);
+auto funcs = jit::GetAllCandidateFuncsWithTypes<jit::SeqPoolTuple<T>, phi::CPUPlace>(attr);
 for (auto f : funcs) {
   LOG(INFO) << "Kernel implementation type: " << f.first;
   f.second(src_data, dst_data, &attr);
 }
 ```
-All kernels are inlcuded in `paddle/fluid/operators/jit/kernels.h`, which is automatically generated in compile time, you can only include this one header to get all the registered kernels.
+All kernels are inlcuded in `paddle/phi/kernels/funcs/jit/kernels.h`, which is automatically generated in compile time, you can only include this one header to get all the registered kernels.
 ## Solid Test
...
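After the move, code outside the `jit` namespace spells the same "get from cache" pattern with the `phi` namespaces. A minimal sketch is shown below; `SumPoolRow` and the float buffers are illustrative only and not part of the patch.

```cpp
#include "paddle/phi/kernels/funcs/jit/kernels.h"

// Illustrative helper: sum-pool a single row of `width` floats.
void SumPoolRow(const float* src_data, float* dst_data, int width) {
  phi::jit::seq_pool_attr_t attr(width, phi::jit::SeqPoolType::kSum);
  attr.h = 1;  // pool one row; real callers set this per sequence
  auto seqpool_func = phi::jit::KernelFuncs<phi::jit::SeqPoolTuple<float>,
                                            phi::CPUPlace>::Cache()
                          .At(attr);
  seqpool_func(src_data, dst_data, &attr);
}
```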
@@ -8,9 +8,9 @@
 ## Directory Structure
 ```txt
-PaddlePaddle/Paddle/paddle/fluid/
+PaddlePaddle/Paddle/paddle/phi/kernels/
 ├── ...
-└── operators/
+└── funcs/
     ├── .../
     └── jit/
         ├── ...
@@ -46,14 +46,14 @@ PaddlePaddle/Paddle/paddle/fluid/
 ### Example
-To call any kernel, you only need to include the header `"paddle/fluid/operators/jit/kernels.h"`, which is generated automatically at compile time.
+To call any kernel, you only need to include the header `"paddle/phi/kernels/funcs/jit/kernels.h"`, which is generated automatically at compile time.
 Fetch the default optimal function directly from the cache:
 ```cpp
 using T = float;
 jit::seq_pool_attr_t attr(width, jit::SeqPoolType::kSum);
-auto seqpool_func = jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(attr);
+auto seqpool_func = jit::KernelFuncs<jit::SeqPoolTuple<T>, phi::CPUPlace>::Cache().At(attr);
 seqpool_func(src_data, dst_data, &attr);
 ```
@@ -62,7 +62,7 @@ PaddlePaddle/Paddle/paddle/fluid/
 ```cpp
 using T = float;
 jit::seq_pool_attr_t attr(width, jit::SeqPoolType::kSum);
-auto funcs = jit::GetAllCandidateFuncsWithTypes<jit::SeqPoolTuple<T>, platform::CPUPlace>(attr);
+auto funcs = jit::GetAllCandidateFuncsWithTypes<jit::SeqPoolTuple<T>, phi::CPUPlace>(attr);
 for (auto f : funcs) {
   LOG(INFO) << "Kernel implementation type: " << f.first;
   f.second(src_data, dst_data, &attr);
...
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,11 +17,11 @@
 #include "gflags/gflags.h"
 #include "glog/logging.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/place.h"
 #include "paddle/phi/api/profiler/device_tracer.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
 DEFINE_int32(burning, 10, "Burning times.");
 DEFINE_int32(repeat, 3000, "Repeat times.");
@@ -106,7 +106,7 @@ struct BenchFunc {
   }
 };
-namespace jit = paddle::operators::jit;
+namespace jit = phi::jit;
 template <typename KernelTuple, typename PlaceType, typename... Args>
 void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
@@ -120,8 +120,7 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
   // Test result from Get function
   auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
   if (!tgt) {
-    PADDLE_THROW(
-        paddle::platform::errors::Fatal("Benchmark target can not be empty."));
+    PADDLE_THROW(phi::errors::Fatal("Benchmark target can not be empty."));
   }
   infos.push_back(std::make_pair("Target", benchmark(tgt, args...)));
@@ -323,7 +322,7 @@ void BenchKernelSgd() {
   PADDLE_ENFORCE_LE(
       static_cast<size_t>(upper - lower),
       n - 1,
-      paddle::platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "The range of Sgd (upper - lower) should be equal to or lower "
          "than n-1 (Sgd size -1). But upper - lower is %d and n-1 is %d.",
          static_cast<size_t>(upper - lower),
@@ -331,7 +330,7 @@ void BenchKernelSgd() {
   PADDLE_ENFORCE_GT(
       n,
       0,
-      paddle::platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "The Sgd size should be larger than 0. But the n is %d.", n));
   std::vector<int64_t> all, out;
   for (int i = 0; i < n; ++i) {
@@ -525,7 +524,7 @@ void BenchKernelVBroadcast() {
 #define BenchKernelGRUHtPart1 BenchKernelGRU
 #define BenchKernelGRUHtPart2 BenchKernelGRU
-using CPUPlace = paddle::platform::CPUPlace;
+using CPUPlace = phi::CPUPlace;
 #define BENCH_FP32_CPU(name) \
 BENCH_JITKERNEL(name, FP32, CPU) { \
...
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,13 +12,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
-#include "paddle/fluid/operators/jit/gen/act.h"
+#include "paddle/phi/kernels/funcs/jit/gen/act.h"
-#include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/phi/backends/cpu/cpu_info.h"
+#include "paddle/phi/kernels/funcs/jit/registry.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -150,10 +149,9 @@ size_t VTanhCreator::CodeSize(const int& d) const {
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
-namespace gen = paddle::operators::jit::gen;
+namespace gen = phi::jit::gen;
 REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
 REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator);
...
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,11 +17,10 @@
 #include <string>
 #include "glog/logging.h"
-#include "paddle/fluid/operators/jit/gen/jitcode.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -91,7 +90,7 @@ class VActFunc : public JitCode {
                  int fy_idx = 13,
                  int mask_idx = 14,
                  int tmp_idx = 15) {
-    using namespace platform;  // NOLINT
+    using namespace phi;  // NOLINT
     // check all idx can not equal
     JMM jmm_src = JMM(src_idx);
     JMM jmm_fx = JMM(fx_idx);
@@ -266,7 +265,7 @@ class VActFunc : public JitCode {
         identity_jmm<JMM>(dst, src, 15);
         break;
       default:
-        PADDLE_THROW(platform::errors::Unimplemented(
+        PADDLE_THROW(phi::errors::Unimplemented(
            "Do not support operand type code: %d.", type));
         break;
     }
@@ -283,7 +282,7 @@ class VActJitCode : public VActFunc {
     if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
          type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
          type_ == operand_type::IDENTITY || type_ == operand_type::SQUARE)) {
-      PADDLE_THROW(platform::errors::Unimplemented(
+      PADDLE_THROW(phi::errors::Unimplemented(
          "Do not support operand type code: %d.", type));
     }
     this->genCode();
@@ -348,5 +347,4 @@ DECLARE_ACT_JITCODE(VTanh, operand_type::TANH);
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
...
@@ -12,15 +12,14 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
-#include "paddle/fluid/operators/jit/gen/adam.h"
+#include "paddle/phi/kernels/funcs/jit/gen/adam.h"
 #include <stddef.h>  // offsetof
-#include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/phi/backends/cpu/cpu_info.h"
+#include "paddle/phi/kernels/funcs/jit/registry.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -145,9 +144,8 @@ class AdamCreator : public JitCodeCreator<adam_attr_t> {
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
-namespace gen = paddle::operators::jit::gen;
+namespace gen = phi::jit::gen;
 REGISTER_JITKERNEL_GEN(kAdam, gen::AdamCreator);
...
@@ -17,11 +17,10 @@
 #include <string>
 #include "glog/logging.h"
-#include "paddle/fluid/operators/jit/gen/jitcode.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -72,5 +71,4 @@ class AdamJitCode : public JitCode {
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
...
@@ -12,15 +12,14 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
-#include "paddle/fluid/operators/jit/gen/adamw.h"
+#include "paddle/phi/kernels/funcs/jit/gen/adamw.h"
 #include <stddef.h>  // offsetof
-#include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/phi/backends/cpu/cpu_info.h"
+#include "paddle/phi/kernels/funcs/jit/registry.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -157,9 +156,8 @@ class AdamWCreator : public JitCodeCreator<int> {
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
-namespace gen = paddle::operators::jit::gen;
+namespace gen = phi::jit::gen;
 REGISTER_JITKERNEL_GEN(kAdamW, gen::AdamWCreator);
...
@@ -17,11 +17,10 @@
 #include <string>
 #include "glog/logging.h"
-#include "paddle/fluid/operators/jit/gen/jitcode.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -78,5 +77,4 @@ class AdamWJitCode : public JitCode {
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,14 +12,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
-#include "paddle/fluid/operators/jit/gen/blas.h"
+#include "paddle/phi/kernels/funcs/jit/gen/blas.h"
-#include "paddle/fluid/operators/jit/macro.h"
-#include "paddle/fluid/operators/jit/registry.h"
 #include "paddle/phi/backends/cpu/cpu_info.h"
+#include "paddle/phi/kernels/funcs/jit/macro.h"
+#include "paddle/phi/kernels/funcs/jit/registry.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -179,10 +178,9 @@ DECLARE_BLAS_CREATOR(VAddBias);
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
-namespace gen = paddle::operators::jit::gen;
+namespace gen = phi::jit::gen;
 REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
 REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
...
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,11 +17,10 @@
 #include <string>
 #include "glog/logging.h"
-#include "paddle/fluid/operators/jit/gen/jitcode.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 namespace jit {
 namespace gen {
@@ -41,7 +40,7 @@ class VXXJitCode : public JitCode {
         with_relu_(with_relu) {
     if (!(type_ == operand_type::MUL || type_ == operand_type::ADD ||
          type_ == operand_type::SUB)) {
-      PADDLE_THROW(platform::errors::Unimplemented(
+      PADDLE_THROW(phi::errors::Unimplemented(
          "Do not support operand type code: %d.", type));
     }
     this->genCode();
@@ -124,5 +123,4 @@ class NCHW16CMulNCJitCode : public JitCode {
 }  // namespace gen
 }  // namespace jit
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,16 +12,15 @@ ...@@ -12,16 +12,15 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/embseqpool.h" #include "paddle/phi/kernels/funcs/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/macro.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -133,31 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> { ...@@ -133,31 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
const emb_seq_pool_attr_t& attr) const override { const emb_seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.table_height, PADDLE_ENFORCE_GT(attr.table_height,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute table_height of EmbSeqPool should " "The attribute table_height of EmbSeqPool should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.table_height)); attr.table_height));
PADDLE_ENFORCE_GT(attr.table_width, PADDLE_ENFORCE_GT(attr.table_width,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute table_width of EmbSeqPool should " "The attribute table_width of EmbSeqPool should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.table_width)); attr.table_width));
PADDLE_ENFORCE_GT(attr.index_height, PADDLE_ENFORCE_GT(attr.index_height,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute index_height of EmbSeqPool should " "The attribute index_height of EmbSeqPool should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.index_height)); attr.index_height));
PADDLE_ENFORCE_GT(attr.index_width, PADDLE_ENFORCE_GT(attr.index_width,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute index_width of EmbSeqPool should " "The attribute index_width of EmbSeqPool should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.index_width)); attr.index_width));
PADDLE_ENFORCE_GT(attr.out_width, PADDLE_ENFORCE_GT(attr.out_width,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute out_width of EmbSeqPool should be " "The attribute out_width of EmbSeqPool should be "
"larger than 0. But it is %d.", "larger than 0. But it is %d.",
attr.out_width)); attr.out_width));
...@@ -167,9 +166,8 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> { ...@@ -167,9 +166,8 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kEmbSeqPool, gen::EmbSeqPoolCreator); REGISTER_JITKERNEL_GEN(kEmbSeqPool, gen::EmbSeqPoolCreator);
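For orientation on what the creator above validates: EmbSeqPool sum-pools rows of an embedding table selected by an index matrix, writing index_width concatenated slots of table_width each. A rough scalar sketch under that assumption (the layout details are inferred, not taken from this hunk):

#include <cstdint>

void EmbSeqPoolSumSketch(const float* table, const int64_t* idx, float* out,
                         int64_t table_width, int64_t index_height,
                         int64_t index_width) {
  // Rough sketch: for each index column w, sum the table rows picked by idx[:, w].
  for (int64_t w = 0; w < index_width; ++w) {
    float* dst = out + w * table_width;
    for (int64_t j = 0; j < table_width; ++j) dst[j] = 0.f;
    for (int64_t h = 0; h < index_height; ++h) {
      const float* src = table + idx[h * index_width + w] * table_width;
      for (int64_t j = 0; j < table_width; ++j) dst[j] += src[j];
    }
  }
}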
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -34,8 +33,7 @@ class EmbSeqPoolJitCode : public JitCode { ...@@ -34,8 +33,7 @@ class EmbSeqPoolJitCode : public JitCode {
tbl_w_(attr.table_width), tbl_w_(attr.table_width),
type_(attr.pool_type) { type_(attr.pool_type) {
if (type_ != SeqPoolType::kSum) { if (type_ != SeqPoolType::kSum) {
PADDLE_THROW( PADDLE_THROW(phi::errors::Unimplemented("Only supports sum pool yet."));
platform::errors::Unimplemented("Only supports sum pool yet."));
} }
this->genCode(); this->genCode();
} }
...@@ -79,5 +77,4 @@ class EmbSeqPoolJitCode : public JitCode { ...@@ -79,5 +77,4 @@ class EmbSeqPoolJitCode : public JitCode {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,16 +12,15 @@ ...@@ -12,16 +12,15 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/gru.h" #include "paddle/phi/kernels/funcs/jit/gen/gru.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/macro.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -110,10 +109,9 @@ DECLARE_GRU_CREATOR(GRUHtPart2); ...@@ -110,10 +109,9 @@ DECLARE_GRU_CREATOR(GRUHtPart2);
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator); REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator); REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" #include "paddle/phi/kernels/funcs/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -42,7 +41,7 @@ class GRUJitCode : public VActFunc { ...@@ -42,7 +41,7 @@ class GRUJitCode : public VActFunc {
} else if (type == KernelType::kVIdentity) { } else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY; return operand_type::IDENTITY;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Do not support jit::KernelType code: %d.", type)); "Do not support jit::KernelType code: %d.", type));
} }
return operand_type::IDENTITY; return operand_type::IDENTITY;
...@@ -114,5 +113,4 @@ DECLARE_GRU_JITCODE(GRUHtPart2, 2); ...@@ -114,5 +113,4 @@ DECLARE_GRU_JITCODE(GRUHtPart2, 2);
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,13 +12,12 @@ ...@@ -12,13 +12,12 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h" #include "paddle/phi/kernels/funcs/jit/gen/hopv.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -95,10 +94,9 @@ DECLARE_HOP_CREATOR(HSum); ...@@ -95,10 +94,9 @@ DECLARE_HOP_CREATOR(HSum);
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator); REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator);
REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator); REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -34,7 +33,7 @@ class HOPVJitCode : public JitCode { ...@@ -34,7 +33,7 @@ class HOPVJitCode : public JitCode {
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d), type_(type) { : JitCode(code_size, code_ptr), num_(d), type_(type) {
if (!(type_ == operand_type::MAX || type_ == operand_type::ADD)) { if (!(type_ == operand_type::MAX || type_ == operand_type::ADD)) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Do not support operand type code: %d.", type)); "Do not support operand type code: %d.", type));
} }
this->genCode(); this->genCode();
...@@ -91,5 +90,4 @@ DECLARE_HOP_JITCODE(HSum, operand_type::ADD); ...@@ -91,5 +90,4 @@ DECLARE_HOP_JITCODE(HSum, operand_type::ADD);
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
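The HOPV jitcode above covers the horizontal reductions registered as kHMax and kHSum. A plain scalar reference of the two operations, for orientation only:

template <typename T>
void HMaxSketch(const T* x, T* res, int n) {
  // res = max of x[0..n)
  *res = x[0];
  for (int i = 1; i < n; ++i) *res = x[i] > *res ? x[i] : *res;
}

template <typename T>
void HSumSketch(const T* x, T* res, int n) {
  // res = sum of x[0..n)
  *res = static_cast<T>(0);
  for (int i = 0; i < n; ++i) *res += x[i];
}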
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,15 +17,14 @@ ...@@ -17,15 +17,14 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/gen_base.h"
#define XBYAK_USE_MMAP_ALLOCATOR #define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h" #include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h" #include "xbyak/xbyak_util.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -131,5 +130,4 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator { ...@@ -131,5 +130,4 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,16 +12,15 @@ ...@@ -12,16 +12,15 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/lstm.h" #include "paddle/phi/kernels/funcs/jit/gen/lstm.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/macro.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -137,10 +136,9 @@ DECLARE_LSTM_CREATOR(LSTMC1H1); ...@@ -137,10 +136,9 @@ DECLARE_LSTM_CREATOR(LSTMC1H1);
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator); REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator); REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" #include "paddle/phi/kernels/funcs/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -45,7 +44,7 @@ class LSTMJitCode : public VActFunc { ...@@ -45,7 +44,7 @@ class LSTMJitCode : public VActFunc {
} else if (type == KernelType::kVIdentity) { } else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY; return operand_type::IDENTITY;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Do not support jit::KernelType code: %d.", type)); "Do not support jit::KernelType code: %d.", type));
} }
return operand_type::IDENTITY; return operand_type::IDENTITY;
...@@ -119,5 +118,4 @@ DECLARE_LSTM_JITCODE(LSTMC1H1, true); ...@@ -119,5 +118,4 @@ DECLARE_LSTM_JITCODE(LSTMC1H1, true);
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,15 +12,14 @@ ...@@ -12,15 +12,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/matmul.h" #include "paddle/phi/kernels/funcs/jit/gen/matmul.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -31,9 +30,9 @@ void MatMulJitCode::genCode() { ...@@ -31,9 +30,9 @@ void MatMulJitCode::genCode() {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
groups.front(), groups.front(),
0, 0,
platform::errors::InvalidArgument("The number of rest registers should " phi::errors::InvalidArgument("The number of rest registers should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
groups.front())); groups.front()));
const int block_len = sizeof(float) * block; const int block_len = sizeof(float) * block;
const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1; const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
...@@ -126,21 +125,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> { ...@@ -126,21 +125,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
attr.m, attr.m,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute m (first matrix's row) of MatMul should " "The attribute m (first matrix's row) of MatMul should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.m)); attr.m));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
attr.n, attr.n,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute n (first matrix's col) of MatMul should " "The attribute n (first matrix's col) of MatMul should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.n)); attr.n));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
attr.k, attr.k,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute k (second matrix's col) of MatMul should " "The attribute k (second matrix's col) of MatMul should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.k)); attr.k));
...@@ -150,9 +149,8 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> { ...@@ -150,9 +149,8 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator); REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -20,11 +20,10 @@ ...@@ -20,11 +20,10 @@
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -34,12 +33,12 @@ class MatMulJitCode : public JitCode { ...@@ -34,12 +33,12 @@ class MatMulJitCode : public JitCode {
size_t code_size = 256 * 1024, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) { : JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) {
PADDLE_ENFORCE_EQ(m_, PADDLE_ENFORCE_EQ(
1, m_,
platform::errors::Unimplemented( 1,
"Jitcode of matmul only support m==1 (first " phi::errors::Unimplemented("Jitcode of matmul only support m==1 (first "
"matrix's row) now. But m is %d.", "matrix's row) now. But m is %d.",
m_)); m_));
this->genCode(); this->genCode();
} }
...@@ -65,5 +64,4 @@ class MatMulJitCode : public JitCode { ...@@ -65,5 +64,4 @@ class MatMulJitCode : public JitCode {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
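The enforce above restricts the generated matmul to m == 1, i.e. a row vector times a k x n matrix. A scalar sketch of that case (names are illustrative; the generated code computes the same result with vector registers):

void MatMul1xK_KxN_Sketch(const float* x, const float* w, float* out, int k, int n) {
  // out[j] = sum over kk of x[kk] * w[kk * n + j]
  for (int j = 0; j < n; ++j) out[j] = 0.f;
  for (int kk = 0; kk < k; ++kk) {
    for (int j = 0; j < n; ++j) out[j] += x[kk] * w[kk * n + j];
  }
}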
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,14 +12,13 @@ ...@@ -12,14 +12,13 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h" #include "paddle/phi/kernels/funcs/jit/gen/seqpool.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -69,27 +68,26 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> { ...@@ -69,27 +68,26 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
} }
std::unique_ptr<GenBase> CreateJitCode( std::unique_ptr<GenBase> CreateJitCode(
const seq_pool_attr_t& attr) const override { const seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.w, PADDLE_ENFORCE_GT(
0, attr.w,
platform::errors::InvalidArgument( 0,
"The attribute width of SeqPool should " phi::errors::InvalidArgument("The attribute width of SeqPool should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.w)); attr.w));
PADDLE_ENFORCE_GT(attr.h, PADDLE_ENFORCE_GT(
0, attr.h,
platform::errors::InvalidArgument( 0,
"The attribute height of SeqPool should " phi::errors::InvalidArgument("The attribute height of SeqPool should "
"be larger than 0. But it is %d.", "be larger than 0. But it is %d.",
attr.h)); attr.h));
return make_unique<SeqPoolJitCode>(attr, CodeSize(attr)); return make_unique<SeqPoolJitCode>(attr, CodeSize(attr));
} }
}; };
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator); REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -33,7 +32,7 @@ class SeqPoolJitCode : public JitCode { ...@@ -33,7 +32,7 @@ class SeqPoolJitCode : public JitCode {
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) { : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg || if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
type_ == SeqPoolType::kSqrt)) { type_ == SeqPoolType::kSqrt)) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Only supports sum, average and sqrt pool type.")); "Only supports sum, average and sqrt pool type."));
} }
fp_h_[0] = 1.f; fp_h_[0] = 1.f;
...@@ -130,7 +129,7 @@ class SeqPoolJitCode : public JitCode { ...@@ -130,7 +129,7 @@ class SeqPoolJitCode : public JitCode {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
reg_idx, reg_idx,
rest_used_num_regs, rest_used_num_regs,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"All heights of SeqPool should use the same number of registers." "All heights of SeqPool should use the same number of registers."
"It equals to the numbr of rest registers. But use %d registers " "It equals to the numbr of rest registers. But use %d registers "
"and the numbr of rest registers is %d.", "and the numbr of rest registers is %d.",
...@@ -221,5 +220,4 @@ class SeqPoolJitCode : public JitCode { ...@@ -221,5 +220,4 @@ class SeqPoolJitCode : public JitCode {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
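For context on the fp_h_ constants prepared above: the three supported pool types differ only in the scale applied to each column sum. A scalar sketch of that behavior (scale values assumed: 1 for kSum, 1/h for kAvg, 1/sqrt(h) for kSqrt):

#include <cmath>

enum class PoolSketchType { kSum, kAvg, kSqrt };

void SeqPoolSketch(const float* x, float* y, int h, int w, PoolSketchType type) {
  // Column-wise pooling over h rows of width w with a type-dependent scale.
  float scale = 1.f;
  if (type == PoolSketchType::kAvg) scale = 1.f / static_cast<float>(h);
  if (type == PoolSketchType::kSqrt) scale = 1.f / std::sqrt(static_cast<float>(h));
  for (int j = 0; j < w; ++j) {
    float sum = 0.f;
    for (int i = 0; i < h; ++i) sum += x[i * w + j];
    y[j] = sum * scale;
  }
}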
...@@ -12,15 +12,14 @@ ...@@ -12,15 +12,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/sgd.h" #include "paddle/phi/kernels/funcs/jit/gen/sgd.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -117,7 +116,7 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> { ...@@ -117,7 +116,7 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
const sgd_attr_t& attr) const override { const sgd_attr_t& attr) const override {
PADDLE_ENFORCE_EQ(attr.param_width, PADDLE_ENFORCE_EQ(attr.param_width,
attr.grad_width, attr.grad_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute param_width of Sgd should be " "The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width " "equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.", "is %d and grad_width is %d.",
...@@ -125,7 +124,7 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> { ...@@ -125,7 +124,7 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
attr.grad_width)); attr.grad_width));
PADDLE_ENFORCE_LE(attr.selected_rows_size, PADDLE_ENFORCE_LE(attr.selected_rows_size,
attr.grad_height, attr.grad_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be " "The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. " "equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.", "But selected_rows_size is %d and grad_height is %d.",
...@@ -134,7 +133,7 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> { ...@@ -134,7 +133,7 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
attr.selected_rows_size, attr.selected_rows_size,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be " "The attribute selected_rows_size of Sgd should be "
"equal to or larger than 0. But selected_rows_size is %d.", "equal to or larger than 0. But selected_rows_size is %d.",
attr.selected_rows_size)); attr.selected_rows_size));
...@@ -144,9 +143,8 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> { ...@@ -144,9 +143,8 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kSgd, gen::SgdCreator); REGISTER_JITKERNEL_GEN(kSgd, gen::SgdCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -59,5 +58,4 @@ class SgdJitCode : public JitCode { ...@@ -59,5 +58,4 @@ class SgdJitCode : public JitCode {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
...@@ -12,13 +12,12 @@ ...@@ -12,13 +12,12 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/vbroadcast.h" #include "paddle/phi/kernels/funcs/jit/gen/vbroadcast.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -79,7 +78,7 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> { ...@@ -79,7 +78,7 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
w, w,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The width of VBroadcast should be larger than 0. But w is %d.", "The width of VBroadcast should be larger than 0. But w is %d.",
w)); w));
return make_unique<VBroadcastJitCode>(w, CodeSize(w)); return make_unique<VBroadcastJitCode>(w, CodeSize(w));
...@@ -88,9 +87,8 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> { ...@@ -88,9 +87,8 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace gen = paddle::operators::jit::gen; namespace gen = phi::jit::gen;
REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator); REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/jit/gen/jitcode.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace gen { namespace gen {
...@@ -51,5 +51,4 @@ class VBroadcastJitCode : public JitCode { ...@@ -51,5 +51,4 @@ class VBroadcastJitCode : public JitCode {
} // namespace gen } // namespace gen
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/phi/kernels/funcs/jit/gen_base.h"
#include <fstream> #include <fstream>
#include "paddle/fluid/memory/allocation/cpu_allocator.h" // for posix_memalign
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/core/enforce.h"
#ifndef _WIN32 #ifdef _WIN32
#define posix_memalign_free _aligned_free
#else
#define posix_memalign_free free #define posix_memalign_free free
#endif #endif
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
// refer does not need CanBeUsed, it would be the last one. // refer does not need CanBeUsed, it would be the last one.
...@@ -48,16 +48,20 @@ void GenBase::dumpCode(const unsigned char* code) const { ...@@ -48,16 +48,20 @@ void GenBase::dumpCode(const unsigned char* code) const {
void* GenBase::operator new(size_t size) { void* GenBase::operator new(size_t size) {
void* ptr; void* ptr;
constexpr size_t alignment = 32ul; constexpr size_t alignment = 32ul;
#ifdef _WIN32
ptr = _aligned_malloc(size, alignment);
#else
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
posix_memalign(&ptr, alignment, size), posix_memalign(&ptr, alignment, size),
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Jitcode generator (GenBase) allocate %ld memory error!", size)); "Jitcode generator (GenBase) allocate %ld memory error!", size));
#endif
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
ptr, ptr,
platform::errors::InvalidArgument("Fail to allocate jitcode generator " phi::errors::InvalidArgument("Fail to allocate jitcode generator "
"(GenBase) CPU memory: size = %d .", "(GenBase) CPU memory: size = %d .",
size)); size));
return ptr; return ptr;
} }
...@@ -93,5 +97,4 @@ std::vector<int> packed_groups(int n, int k, int* block_out, int* rest_out) { ...@@ -93,5 +97,4 @@ std::vector<int> packed_groups(int n, int k, int* block_out, int* rest_out) {
} }
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
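The hunk above replaces the fluid CPU allocator include with a direct _WIN32 branch: _aligned_malloc/_aligned_free on Windows, posix_memalign/free elsewhere. A condensed sketch of the same pattern without the enforce plumbing (the 32-byte alignment matches the constant in the hunk):

#include <cstddef>
#include <cstdlib>
#ifdef _WIN32
#include <malloc.h>  // _aligned_malloc / _aligned_free
#endif

void* AlignedAllocSketch(std::size_t size, std::size_t alignment = 32) {
#ifdef _WIN32
  return _aligned_malloc(size, alignment);
#else
  void* ptr = nullptr;
  return posix_memalign(&ptr, alignment, size) == 0 ? ptr : nullptr;
#endif
}

void AlignedFreeSketch(void* ptr) {
#ifdef _WIN32
  _aligned_free(ptr);
#else
  free(ptr);
#endif
}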
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -18,13 +18,16 @@ ...@@ -18,13 +18,16 @@
#include <string> #include <string>
#include <vector> #include <vector>
#ifdef _WIN32
#include <malloc.h> // for _aligned_malloc
#endif
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
DECLARE_bool(dump_jitcode); DECLARE_bool(dump_jitcode);
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
class GenBase : public Kernel { class GenBase : public Kernel {
...@@ -84,5 +87,4 @@ std::vector<int> packed_groups(int n, ...@@ -84,5 +87,4 @@ std::vector<int> packed_groups(int n,
int* rest = nullptr); int* rest = nullptr);
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
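The DECLARE_bool(dump_jitcode) kept above means the debugging switch still works after the move; when it is on, dumpCode writes each generated kernel to a file. A minimal usage sketch (the flag name comes from the hunk, everything else is illustrative):

#include "gflags/gflags.h"

DECLARE_bool(dump_jitcode);

void EnableJitcodeDumpSketch() {
  // Turn on dumping before the first kernel is generated, e.g. in test setup.
  FLAGS_dump_jitcode = true;
}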
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,14 +12,13 @@ ...@@ -12,14 +12,13 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h" #include "paddle/phi/kernels/funcs/jit/helper.h"
#include <numeric> #include <numeric>
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/core/enforce.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
std::map<size_t, std::shared_ptr<void>>& GetFuncCacheMap() { std::map<size_t, std::shared_ptr<void>>& GetFuncCacheMap() {
...@@ -68,7 +67,7 @@ const char* to_string(KernelType kt) { ...@@ -68,7 +67,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kEmbSeqPool); ONE_CASE(kEmbSeqPool);
ONE_CASE(kSgd); ONE_CASE(kSgd);
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"JIT kernel do not support type: %d.", kt)); "JIT kernel do not support type: %d.", kt));
return "NOT JITKernel"; return "NOT JITKernel";
} }
...@@ -82,7 +81,7 @@ const char* to_string(SeqPoolType tp) { ...@@ -82,7 +81,7 @@ const char* to_string(SeqPoolType tp) {
ONE_CASE(kAvg); ONE_CASE(kAvg);
ONE_CASE(kSqrt); ONE_CASE(kSqrt);
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"SeqPool JIT kernel do not support type: %d.", tp)); "SeqPool JIT kernel do not support type: %d.", tp));
return "NOT PoolType"; return "NOT PoolType";
} }
...@@ -104,7 +103,7 @@ KernelType to_kerneltype(const std::string& act) { ...@@ -104,7 +103,7 @@ KernelType to_kerneltype(const std::string& act) {
} else if (lower == "tanh" || lower == "vtanh") { } else if (lower == "tanh" || lower == "vtanh") {
return kVTanh; return kVTanh;
} }
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Act JIT kernel do not support type: %s.", act)); "Act JIT kernel do not support type: %s.", act));
return kNone; return kNone;
} }
...@@ -116,7 +115,7 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) { ...@@ -116,7 +115,7 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
std::for_each(groups.begin(), groups.end(), [&](int i) { std::for_each(groups.begin(), groups.end(), [&](int i) {
PADDLE_ENFORCE_GT(i, PADDLE_ENFORCE_GT(i,
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Each element of groups should be larger than " "Each element of groups should be larger than "
"0. However the element: %d doesn't satify.", "0. However the element: %d doesn't satify.",
i)); i));
...@@ -125,7 +124,7 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) { ...@@ -125,7 +124,7 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
std::memset(dst, 0, k * sum * block * sizeof(float)); std::memset(dst, 0, k * sum * block * sizeof(float));
PADDLE_ENFORCE_GE(sum * block, PADDLE_ENFORCE_GE(sum * block,
n, n,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The packed n (sum * block) should be equal to or " "The packed n (sum * block) should be equal to or "
"larger than n (matmul row size). " "larger than n (matmul row size). "
"However, the packed n is %d and n is %d.", "However, the packed n is %d and n is %d.",
...@@ -152,10 +151,9 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) { ...@@ -152,10 +151,9 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
template <typename T> template <typename T>
typename std::enable_if<!std::is_same<T, float>::value>::type pack_weights( typename std::enable_if<!std::is_same<T, float>::value>::type pack_weights(
const T* src, T* dst, int n, int k) { const T* src, T* dst, int n, int k) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Only supports pack weights with float type.")); "Only supports pack weights with float type."));
} }
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
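After the move, the string helpers above live in phi::jit. A small usage sketch of to_kerneltype (the tanh mapping is taken from the hunk; other activation spellings follow the same lowercase pattern but are not shown here):

#include <string>

#include "paddle/phi/kernels/funcs/jit/helper.h"

void ToKernelTypeSketch() {
  // "tanh" and "vtanh" both map to kVTanh; unknown names throw Unimplemented.
  phi::jit::KernelType kt = phi::jit::to_kerneltype(std::string("tanh"));
  (void)kt;
}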
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -14,22 +14,22 @@ ...@@ -14,22 +14,22 @@
#pragma once #pragma once
#include <cstring>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> // for std::move #include <utility> // for std::move
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/phi/kernels/funcs/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/kernels/funcs/jit/kernel_key.h"
#include "paddle/phi/kernels/funcs/jit/kernel_pool.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
class GenBase; class GenBase;
...@@ -37,7 +37,7 @@ class GenBase; ...@@ -37,7 +37,7 @@ class GenBase;
template <typename KernelTuple, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
std::is_same<typename KernelTuple::data_type, float>::value && std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value, std::is_same<PlaceType, phi::CPUPlace>::value,
const Kernel*>::type const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
using Attr = typename KernelTuple::attr_type; using Attr = typename KernelTuple::attr_type;
...@@ -72,7 +72,7 @@ GetJitCode(const typename KernelTuple::attr_type& attr) { ...@@ -72,7 +72,7 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
template <typename KernelTuple, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
!std::is_same<typename KernelTuple::data_type, float>::value || !std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value, !std::is_same<PlaceType, phi::CPUPlace>::value,
const Kernel*>::type const Kernel*>::type
GetJitCode(const typename KernelTuple::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr; return nullptr;
...@@ -83,12 +83,12 @@ GetJitCode(const typename KernelTuple::attr_type& attr) { ...@@ -83,12 +83,12 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
template <typename KernelTuple> template <typename KernelTuple>
inline const Kernel* GetReferKernel() { inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool::Instance().AllKernels(); auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace()); KernelKey kkey(KernelTuple::kernel_type, phi::CPUPlace());
auto ref_iter = ref_pool.find(kkey); auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
ref_iter, ref_iter,
ref_pool.end(), ref_pool.end(),
platform::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Every Refer Kernel of jitcode should have reference function.")); "Every Refer Kernel of jitcode should have reference function."));
auto& ref_impls = ref_iter->second; auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) { for (auto& impl : ref_impls) {
...@@ -104,10 +104,10 @@ template <typename KernelTuple> ...@@ -104,10 +104,10 @@ template <typename KernelTuple>
inline typename KernelTuple::func_type GetReferFunc() { inline typename KernelTuple::func_type GetReferFunc() {
auto ker = GetReferKernel<KernelTuple>(); auto ker = GetReferKernel<KernelTuple>();
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker); auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
PADDLE_ENFORCE_NOT_NULL(p, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( p,
"Get the reference code of kernel in CPU " phi::errors::InvalidArgument("Get the reference code of kernel in CPU "
"failed. The Refer kernel should exsit.")); "failed. The Refer kernel should exsit."));
return p->GetFunc(); return p->GetFunc();
} }
...@@ -138,15 +138,15 @@ std::vector<const Kernel*> GetAllCandidateKernels( ...@@ -138,15 +138,15 @@ std::vector<const Kernel*> GetAllCandidateKernels(
// The last implementation should be reference function on CPUPlace. // The last implementation should be reference function on CPUPlace.
auto ref = GetReferKernel<KernelTuple>(); auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE_NOT_NULL(ref, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( ref,
"Get all candicate kernel in CPU failed. " phi::errors::InvalidArgument("Get all candicate kernel in CPU failed. "
"The Refer Kernel can not be empty.")); "The Refer Kernel can not be empty."));
res.emplace_back(ref); res.emplace_back(ref);
return res; return res;
} }
template <typename KernelTuple, typename PlaceType = platform::CPUPlace> template <typename KernelTuple, typename PlaceType = phi::CPUPlace>
std::vector<std::pair<std::string, typename KernelTuple::func_type>> std::vector<std::pair<std::string, typename KernelTuple::func_type>>
GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) { GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type; using Func = typename KernelTuple::func_type;
...@@ -157,21 +157,20 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) { ...@@ -157,21 +157,20 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
if (name == "JitCode") { if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k); auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE_NOT_NULL(i, PADDLE_ENFORCE_NOT_NULL(i,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Generate jitcode kernel (GenBase) failed.")); "Generate jitcode kernel (GenBase) failed."));
res.emplace_back(std::make_pair(name, i->template getCode<Func>())); res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else { } else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k); auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE_NOT_NULL(i, PADDLE_ENFORCE_NOT_NULL(
platform::errors::InvalidArgument( i, phi::errors::InvalidArgument("Kernel cast (KernelMore) failed."));
"Kernel cast (KernelMore) failed."));
res.emplace_back(std::make_pair(name, i->GetFunc())); res.emplace_back(std::make_pair(name, i->GetFunc()));
} }
} }
return res; return res;
} }
template <typename KernelTuple, typename PlaceType = platform::CPUPlace> template <typename KernelTuple, typename PlaceType = phi::CPUPlace>
std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs( std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
const typename KernelTuple::attr_type& attr) { const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr); auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
...@@ -182,13 +181,13 @@ std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs( ...@@ -182,13 +181,13 @@ std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
return res; return res;
} }
template <typename KernelTuple, typename PlaceType = platform::CPUPlace> template <typename KernelTuple, typename PlaceType = phi::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc( typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) { const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr); auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), PADDLE_ENFORCE_GE(funcs.size(),
1UL, 1UL,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The candicate jit kernel is at least one in CPU.")); "The candicate jit kernel is at least one in CPU."));
// Here could do some runtime benchmark of this attr and return the best one. // Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one, // But yet just get the first one as the default best one,
...@@ -303,5 +302,4 @@ template <typename T> ...@@ -303,5 +302,4 @@ template <typename T>
void pack_weights(const T* src, T* dst, int n, int k); void pack_weights(const T* src, T* dst, int n, int k);
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
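A hedged sketch of picking a kernel through the relocated helpers above; VMulTuple and its (x, y, z, n) signature are assumed from the wider jit code base rather than shown in this hunk:

#include "paddle/phi/kernels/funcs/jit/helper.h"
#include "paddle/phi/kernels/funcs/jit/kernels.h"

void VMulSketch(const float* x, const float* y, float* z, int n) {
  // Jitcode is preferred if usable, then "More" kernels, then the refer kernel.
  auto vmul =
      phi::jit::GetDefaultBestFunc<phi::jit::VMulTuple<float>, phi::CPUPlace>(n);
  vmul(x, y, z, n);
}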
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include "paddle/fluid/operators/jit/macro.h" #include "paddle/phi/core/macros.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/phi/kernels/funcs/jit/macro.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
typedef enum { typedef enum {
...@@ -403,5 +402,4 @@ class ReferKernel : public KernelMore<KernelTuple> { ...@@ -403,5 +402,4 @@ class ReferKernel : public KernelMore<KernelTuple> {
}; };
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,12 +12,11 @@ ...@@ -12,12 +12,11 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/phi/kernels/funcs/jit/kernel_key.h"
#include <xxhash.h> // XXH64: 13.8 GB/s #include <xxhash.h> // XXH64: 13.8 GB/s
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
template <> template <>
...@@ -72,5 +71,4 @@ int64_t JitCodeKey<adam_attr_t>(const adam_attr_t& attr) { ...@@ -72,5 +71,4 @@ int64_t JitCodeKey<adam_attr_t>(const adam_attr_t& attr) {
} }
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -13,11 +13,10 @@ ...@@ -13,11 +13,10 @@
* limitations under the License. */ * limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
struct KernelKey { struct KernelKey {
...@@ -31,15 +30,13 @@ struct KernelKey { ...@@ -31,15 +30,13 @@ struct KernelKey {
}; };
KernelType type_; KernelType type_;
platform::Place place_; phi::Place place_;
KernelKey(KernelType type, platform::Place place) KernelKey(KernelType type, phi::Place place) : type_(type), place_(place) {}
: type_(type), place_(place) {}
size_t hash_key() const { return Hash()(*this); } size_t hash_key() const { return Hash()(*this); }
bool operator==(const KernelKey& o) const { bool operator==(const KernelKey& o) const {
return platform::places_are_same_class(place_, o.place_) && return place_ == o.place_ && type_ == o.type_;
type_ == o.type_;
} }
bool operator!=(const KernelKey& o) const { return !(*this == o); } bool operator!=(const KernelKey& o) const { return !(*this == o); }
}; };
...@@ -49,5 +46,4 @@ template <typename Attr> ...@@ -49,5 +46,4 @@ template <typename Attr>
int64_t JitCodeKey(const Attr& attr); int64_t JitCodeKey(const Attr& attr);
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
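The equality change above compares phi::Place directly instead of platform::places_are_same_class. A small sketch of how keys compare after the move (GPUPlace is used only for illustration):

#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/jit/kernel_key.h"

void KernelKeySketch() {
  phi::jit::KernelKey cpu_key(phi::jit::kVMul, phi::CPUPlace());
  phi::jit::KernelKey gpu_key(phi::jit::kVMul, phi::GPUPlace(0));
  bool same = (cpu_key == gpu_key);  // false: same kernel type, different place
  (void)same;
}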
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,10 +12,9 @@ ...@@ -12,10 +12,9 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_pool.h" #include "paddle/phi/kernels/funcs/jit/kernel_pool.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
std::map<size_t, std::shared_ptr<void>>& GetJITCodesMap() { std::map<size_t, std::shared_ptr<void>>& GetJITCodesMap() {
...@@ -39,5 +38,4 @@ ReferKernelPool& ReferKernelPool::Instance() { ...@@ -39,5 +38,4 @@ ReferKernelPool& ReferKernelPool::Instance() {
} }
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -21,13 +21,12 @@ ...@@ -21,13 +21,12 @@
#include <utility> // for move #include <utility> // for move
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/kernels/funcs/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/kernels/funcs/jit/kernel_key.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
struct KernelKey; struct KernelKey;
...@@ -130,5 +129,4 @@ class ReferKernelPool { ...@@ -130,5 +129,4 @@ class ReferKernelPool {
}; };
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
#pragma once #pragma once
#include <type_traits> #include <type_traits>
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
#define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MIN -40.0
...@@ -28,5 +27,4 @@ namespace jit { ...@@ -28,5 +27,4 @@ namespace jit {
#define ZMM_FLOAT_BLOCK 16 #define ZMM_FLOAT_BLOCK 16
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,15 +12,14 @@ ...@@ -12,15 +12,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h" #include "paddle/phi/kernels/funcs/jit/more/intrinsic/crf_decoding.h"
#include <limits> #include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace intrinsic { namespace intrinsic {
...@@ -178,9 +177,8 @@ bool CRFDecodingKernel::CanBeUsed(const int& d) const { ...@@ -178,9 +177,8 @@ bool CRFDecodingKernel::CanBeUsed(const int& d) const {
} // namespace intrinsic } // namespace intrinsic
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace intrinsic = paddle::operators::jit::more::intrinsic; namespace intrinsic = phi::jit::more::intrinsic;
REGISTER_JITKERNEL_MORE(kCRFDecoding, intrinsic, intrinsic::CRFDecodingKernel); REGISTER_JITKERNEL_MORE(kCRFDecoding, intrinsic, intrinsic::CRFDecodingKernel);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace intrinsic { namespace intrinsic {
...@@ -42,5 +41,4 @@ class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> { ...@@ -42,5 +41,4 @@ class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
} // namespace intrinsic } // namespace intrinsic
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,15 +12,14 @@ ...@@ -12,15 +12,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/layer_norm.h" #include "paddle/phi/kernels/funcs/jit/more/intrinsic/layer_norm.h"
#include <limits> #include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace intrinsic { namespace intrinsic {
...@@ -186,9 +185,8 @@ bool LayerNormKernel::CanBeUsed(const int& d) const { ...@@ -186,9 +185,8 @@ bool LayerNormKernel::CanBeUsed(const int& d) const {
} // namespace intrinsic } // namespace intrinsic
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace intrinsic = paddle::operators::jit::more::intrinsic; namespace intrinsic = phi::jit::more::intrinsic;
REGISTER_JITKERNEL_MORE(kLayerNorm, intrinsic, intrinsic::LayerNormKernel); REGISTER_JITKERNEL_MORE(kLayerNorm, intrinsic, intrinsic::LayerNormKernel);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace intrinsic { namespace intrinsic {
...@@ -45,5 +44,4 @@ class LayerNormKernel : public KernelMore<LayerNormTuple<float>> { ...@@ -45,5 +44,4 @@ class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
} // namespace intrinsic } // namespace intrinsic
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,18 +12,17 @@ ...@@ -12,18 +12,17 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/more/mix/mix.h" #include "paddle/phi/kernels/funcs/jit/more/mix/mix.h"
#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/phi/kernels/funcs/jit/kernels.h"
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace mix { namespace mix {
using CPUPlace = platform::CPUPlace; using CPUPlace = phi::CPUPlace;
void VSigmoid(const T* x, T* y, int n) { void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN; const float min = SIGMOID_THRESHOLD_MIN;
...@@ -95,7 +94,7 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT ...@@ -95,7 +94,7 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
} else if (type == kVIdentity) { } else if (type == kVIdentity) {
return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d); return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
} }
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Act JIT kernel do not support type: %s", type)); "Act JIT kernel do not support type: %s", type));
return nullptr; return nullptr;
} }
...@@ -237,10 +236,9 @@ bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; } ...@@ -237,10 +236,9 @@ bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
} // namespace mix } // namespace mix
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace mix = paddle::operators::jit::more::mix; namespace mix = phi::jit::more::mix;
#define REGISTER_MORE_KERNEL(func) \ #define REGISTER_MORE_KERNEL(func) \
REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel) REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace mix { namespace mix {
...@@ -62,5 +61,4 @@ DECLARE_MORE_KERNEL(GRUHtPart2); ...@@ -62,5 +61,4 @@ DECLARE_MORE_KERNEL(GRUHtPart2);
} // namespace mix } // namespace mix
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,15 +12,14 @@ ...@@ -12,15 +12,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/more/mkl/mkl.h" #include "paddle/phi/kernels/funcs/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/dynload/mklml.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/backends/dynload/mklml.h"
#include "paddle/phi/kernels/funcs/jit/refer/refer.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace mkl { namespace mkl {
...@@ -30,20 +29,20 @@ void MatMul<float>(const float* a, ...@@ -30,20 +29,20 @@ void MatMul<float>(const float* a,
const float* b, const float* b,
float* c, float* c,
const matmul_attr_t* attr) { const matmul_attr_t* attr) {
platform::dynload::cblas_sgemm(CblasRowMajor, phi::dynload::cblas_sgemm(CblasRowMajor,
CblasNoTrans, CblasNoTrans,
CblasNoTrans, CblasNoTrans,
attr->m, attr->m,
attr->n, attr->n,
attr->k, attr->k,
1.f, 1.f,
a, a,
attr->k, attr->k,
b, b,
attr->n, attr->n,
0.f, 0.f,
c, c,
attr->n); attr->n);
} }
template <> template <>
...@@ -51,46 +50,46 @@ void MatMul<double>(const double* a, ...@@ -51,46 +50,46 @@ void MatMul<double>(const double* a,
const double* b, const double* b,
double* c, double* c,
const matmul_attr_t* attr) { const matmul_attr_t* attr) {
platform::dynload::cblas_dgemm(CblasRowMajor, phi::dynload::cblas_dgemm(CblasRowMajor,
CblasNoTrans, CblasNoTrans,
CblasNoTrans, CblasNoTrans,
attr->m, attr->m,
attr->n, attr->n,
attr->k, attr->k,
1.0, 1.0,
a, a,
attr->k, attr->k,
b, b,
attr->n, attr->n,
0.0, 0.0,
c, c,
attr->n); attr->n);
} }
template <> template <>
void VMul<float>(const float* x, const float* y, float* z, int n) { void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z); phi::dynload::vsMul(n, x, y, z);
} }
template <> template <>
void VMul<double>(const double* x, const double* y, double* z, int n) { void VMul<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z); phi::dynload::vdMul(n, x, y, z);
} }
template <> template <>
void VAdd<float>(const float* x, const float* y, float* z, int n) { void VAdd<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z); phi::dynload::vsAdd(n, x, y, z);
} }
template <> template <>
void VAdd<double>(const double* x, const double* y, double* z, int n) { void VAdd<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z); phi::dynload::vdAdd(n, x, y, z);
} }
template <> template <>
void VScal<float>(const float* a, const float* x, float* y, int n) { void VScal<float>(const float* a, const float* x, float* y, int n) {
if (x == y) { if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1); phi::dynload::cblas_sscal(n, *a, y, 1);
} else { } else {
refer::VScal<float>(a, x, y, n); refer::VScal<float>(a, x, y, n);
} }
...@@ -99,7 +98,7 @@ void VScal<float>(const float* a, const float* x, float* y, int n) { ...@@ -99,7 +98,7 @@ void VScal<float>(const float* a, const float* x, float* y, int n) {
template <> template <>
void VScal<double>(const double* a, const double* x, double* y, int n) { void VScal<double>(const double* a, const double* x, double* y, int n) {
if (x == y) { if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1); phi::dynload::cblas_dscal(n, *a, y, 1);
} else { } else {
refer::VScal<double>(a, x, y, n); refer::VScal<double>(a, x, y, n);
} }
...@@ -109,7 +108,7 @@ template <> ...@@ -109,7 +108,7 @@ template <>
void StrideScal<float>( void StrideScal<float>(
const float* a, const float* x, float* y, int n, int stride) { const float* a, const float* x, float* y, int n, int stride) {
if (x == y) { if (x == y) {
platform::dynload::cblas_sscal(n / stride, *a, y, stride); phi::dynload::cblas_sscal(n / stride, *a, y, stride);
} else { } else {
refer::StrideScal<float>(a, x, y, n, stride); refer::StrideScal<float>(a, x, y, n, stride);
} }
...@@ -119,7 +118,7 @@ template <> ...@@ -119,7 +118,7 @@ template <>
void StrideScal<double>( void StrideScal<double>(
const double* a, const double* x, double* y, int n, int stride) { const double* a, const double* x, double* y, int n, int stride) {
if (x == y) { if (x == y) {
platform::dynload::cblas_dscal(n / stride, *a, y, stride); phi::dynload::cblas_dscal(n / stride, *a, y, stride);
} else { } else {
refer::StrideScal<double>(a, x, y, n, stride); refer::StrideScal<double>(a, x, y, n, stride);
} }
...@@ -127,62 +126,62 @@ void StrideScal<double>( ...@@ -127,62 +126,62 @@ void StrideScal<double>(
template <> template <>
void VExp<float>(const float* x, float* y, int n) { void VExp<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y); phi::dynload::vsExp(n, x, y);
} }
template <> template <>
void VExp<double>(const double* x, double* y, int n) { void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y); phi::dynload::vdExp(n, x, y);
} }
template <> template <>
void VSquare<float>(const float* x, float* y, int n) { void VSquare<float>(const float* x, float* y, int n) {
platform::dynload::vsSqr(n, x, y); phi::dynload::vsSqr(n, x, y);
} }
template <> template <>
void VSquare<double>(const double* x, double* y, int n) { void VSquare<double>(const double* x, double* y, int n) {
platform::dynload::vdSqr(n, x, y); phi::dynload::vdSqr(n, x, y);
} }
template <> template <>
void VCopy<float>(const float* x, float* y, int n) { void VCopy<float>(const float* x, float* y, int n) {
platform::dynload::cblas_scopy(n, x, 1, y, 1); phi::dynload::cblas_scopy(n, x, 1, y, 1);
} }
template <> template <>
void VCopy<double>(const double* x, double* y, int n) { void VCopy<double>(const double* x, double* y, int n) {
platform::dynload::cblas_dcopy(n, x, 1, y, 1); phi::dynload::cblas_dcopy(n, x, 1, y, 1);
} }
template <> template <>
void VAXPY<float>(float a, const float* x, float* y, int n) { void VAXPY<float>(float a, const float* x, float* y, int n) {
platform::dynload::cblas_saxpy(n, a, x, 1, y, 1); phi::dynload::cblas_saxpy(n, a, x, 1, y, 1);
} }
template <> template <>
void VAXPY<double>(double a, const double* x, double* y, int n) { void VAXPY<double>(double a, const double* x, double* y, int n) {
platform::dynload::cblas_daxpy(n, a, x, 1, y, 1); phi::dynload::cblas_daxpy(n, a, x, 1, y, 1);
} }
template <> template <>
void ASum<float>(const float* x, float* res, int n) { void ASum<float>(const float* x, float* res, int n) {
res[0] = platform::dynload::cblas_sasum(n, x, 1); res[0] = phi::dynload::cblas_sasum(n, x, 1);
} }
template <> template <>
void ASum<double>(const double* x, double* res, int n) { void ASum<double>(const double* x, double* res, int n) {
res[0] = platform::dynload::cblas_dasum(n, x, 1); res[0] = phi::dynload::cblas_dasum(n, x, 1);
} }
template <> template <>
void StrideASum<float>(const float* x, float* res, int n, int stride) { void StrideASum<float>(const float* x, float* res, int n, int stride) {
res[0] = platform::dynload::cblas_sasum(n / stride, x, stride); res[0] = phi::dynload::cblas_sasum(n / stride, x, stride);
} }
template <> template <>
void StrideASum<double>(const double* x, double* res, int n, int stride) { void StrideASum<double>(const double* x, double* res, int n, int stride) {
res[0] = platform::dynload::cblas_dasum(n / stride, x, stride); res[0] = phi::dynload::cblas_dasum(n / stride, x, stride);
} }
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
...@@ -309,10 +308,9 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax); ...@@ -309,10 +308,9 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
} // namespace mkl } // namespace mkl
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
namespace mkl = paddle::operators::jit::more::mkl; namespace mkl = phi::jit::more::mkl;
#define REGISTER_MKL_KERNEL(func) \ #define REGISTER_MKL_KERNEL(func) \
REGISTER_JITKERNEL_MORE( \ REGISTER_JITKERNEL_MORE( \
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -18,11 +18,10 @@ ...@@ -18,11 +18,10 @@
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace more { namespace more {
namespace mkl { namespace mkl {
...@@ -108,7 +107,7 @@ void EmbSeqPool(const T* table, ...@@ -108,7 +107,7 @@ void EmbSeqPool(const T* table,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
attr->table_width * attr->index_width, attr->table_width * attr->index_width,
attr->out_width, attr->out_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute table_width * index_width of EmbSeqPool should " "The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d, " "be equal to out_width. But table_width * index_width is %d, "
"out_width is %d.", "out_width is %d.",
...@@ -118,19 +117,19 @@ void EmbSeqPool(const T* table, ...@@ -118,19 +117,19 @@ void EmbSeqPool(const T* table,
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
idx[i], idx[i],
attr->table_height, attr->table_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The idx shoud be lower than the attribute table_height of " "The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d.", "EmbSeqPool. But %dth of idx is %d and table_height is %d.",
i, i,
idx[i], idx[i],
attr->table_height)); attr->table_height));
PADDLE_ENFORCE_GE(idx[i], PADDLE_ENFORCE_GE(
0, idx[i],
platform::errors::InvalidArgument( 0,
"The idx shoud be equal to or larger than " phi::errors::InvalidArgument("The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d.", "the 0. But %dth of idx is %d.",
i, i,
idx[i])); idx[i]));
}; };
for (int64_t w = 0; w != attr->index_width; ++w) { for (int64_t w = 0; w != attr->index_width; ++w) {
...@@ -200,7 +199,7 @@ void Sgd(const T* lr, ...@@ -200,7 +199,7 @@ void Sgd(const T* lr,
const sgd_attr_t* attr) { const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, PADDLE_ENFORCE_EQ(attr->param_width,
attr->grad_width, attr->grad_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute param_width of Sgd should be " "The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width " "equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.", "is %d and grad_width is %d.",
...@@ -208,7 +207,7 @@ void Sgd(const T* lr, ...@@ -208,7 +207,7 @@ void Sgd(const T* lr,
attr->grad_width)); attr->grad_width));
PADDLE_ENFORCE_LE(attr->selected_rows_size, PADDLE_ENFORCE_LE(attr->selected_rows_size,
attr->grad_height, attr->grad_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be " "The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. " "equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.", "But selected_rows_size is %d and grad_height is %d.",
...@@ -221,7 +220,7 @@ void Sgd(const T* lr, ...@@ -221,7 +220,7 @@ void Sgd(const T* lr,
auto h_idx = rows[i]; auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, PADDLE_ENFORCE_LT(h_idx,
attr->param_height, attr->param_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The rows of Sgd should be " "The rows of Sgd should be "
"less than the attribute. But %dth of rows " "less than the attribute. But %dth of rows "
"is %d and grad_width is %d.", "is %d and grad_width is %d.",
...@@ -231,11 +230,11 @@ void Sgd(const T* lr, ...@@ -231,11 +230,11 @@ void Sgd(const T* lr,
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
h_idx, h_idx,
0, 0,
platform::errors::InvalidArgument("The rows of Sgd should be " phi::errors::InvalidArgument("The rows of Sgd should be "
"larger than 0. But %dth of rows " "larger than 0. But %dth of rows "
"is %d.", "is %d.",
i, i,
h_idx)); h_idx));
VAXPY(scalar, grad + i * width, out + h_idx * width, width); VAXPY(scalar, grad + i * width, out + h_idx * width, width);
} }
} else { } else {
...@@ -243,7 +242,7 @@ void Sgd(const T* lr, ...@@ -243,7 +242,7 @@ void Sgd(const T* lr,
auto h_idx = rows[i]; auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, PADDLE_ENFORCE_LT(h_idx,
attr->param_height, attr->param_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The rows of Sgd should be " "The rows of Sgd should be "
"less than the attribute. But %dth of rows " "less than the attribute. But %dth of rows "
"is %d and grad_width is %d.", "is %d and grad_width is %d.",
...@@ -253,11 +252,11 @@ void Sgd(const T* lr, ...@@ -253,11 +252,11 @@ void Sgd(const T* lr,
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
h_idx, h_idx,
0, 0,
platform::errors::InvalidArgument("The rows of Sgd should be " phi::errors::InvalidArgument("The rows of Sgd should be "
"larger than 0. But %dth of rows " "larger than 0. But %dth of rows "
"is %d.", "is %d.",
i, i,
h_idx)); h_idx));
VScal(&scalar, grad + i * width, out + h_idx * width, width); VScal(&scalar, grad + i * width, out + h_idx * width, width);
VAdd(param + h_idx * width, VAdd(param + h_idx * width,
out + h_idx * width, out + h_idx * width,
...@@ -306,5 +305,4 @@ DECLARE_MKL_KERNEL(VBroadcast); ...@@ -306,5 +305,4 @@ DECLARE_MKL_KERNEL(VBroadcast);
} // namespace mkl } // namespace mkl
} // namespace more } // namespace more
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/refer/refer.h" #include "paddle/phi/kernels/funcs/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/phi/kernels/funcs/jit/registry.h"
namespace refer = paddle::operators::jit::refer; namespace refer = phi::jit::refer;
#define REGISTER_REFER_KERNEL(func) \ #define REGISTER_REFER_KERNEL(func) \
REGISTER_JITKERNEL_REFER( \ REGISTER_JITKERNEL_REFER( \
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -18,12 +18,11 @@ ...@@ -18,12 +18,11 @@
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/jit/helper.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/kernels/funcs/jit/helper.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/funcs/jit/kernel_base.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
namespace refer { namespace refer {
...@@ -147,7 +146,7 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT ...@@ -147,7 +146,7 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
} else if (type == kVIdentity) { } else if (type == kVIdentity) {
return VIdentity<T>; return VIdentity<T>;
} }
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Act JIT kernel do not support type: %s.", type)); "Act JIT kernel do not support type: %s.", type));
return nullptr; return nullptr;
} }
...@@ -482,7 +481,7 @@ void EmbSeqPool(const T* table, ...@@ -482,7 +481,7 @@ void EmbSeqPool(const T* table,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
attr->table_width * attr->index_width, attr->table_width * attr->index_width,
attr->out_width, attr->out_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute table_width * index_width of EmbSeqPool should " "The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d and " "be equal to out_width. But table_width * index_width is %d and "
"out_width is %d.", "out_width is %d.",
...@@ -493,19 +492,19 @@ void EmbSeqPool(const T* table, ...@@ -493,19 +492,19 @@ void EmbSeqPool(const T* table,
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
idx[i], idx[i],
attr->table_height, attr->table_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The idx shoud be lower than the attribute table_height of " "The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d.", "EmbSeqPool. But %dth of idx is %d and table_height is %d.",
i, i,
idx[i], idx[i],
attr->table_height)); attr->table_height));
PADDLE_ENFORCE_GE(idx[i], PADDLE_ENFORCE_GE(
0, idx[i],
platform::errors::InvalidArgument( 0,
"The idx shoud be equal to or larger than " phi::errors::InvalidArgument("The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d.", "the 0. But %dth of idx is %d.",
i, i,
idx[i])); idx[i]));
}; };
for (int64_t w = 0; w != attr->index_width; ++w) { for (int64_t w = 0; w != attr->index_width; ++w) {
...@@ -549,7 +548,7 @@ void Sgd(const T* lr, ...@@ -549,7 +548,7 @@ void Sgd(const T* lr,
const sgd_attr_t* attr) { const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, PADDLE_ENFORCE_EQ(attr->param_width,
attr->grad_width, attr->grad_width,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute param_width of Sgd should be " "The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width " "equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.", "is %d and grad_width is %d.",
...@@ -557,7 +556,7 @@ void Sgd(const T* lr, ...@@ -557,7 +556,7 @@ void Sgd(const T* lr,
attr->grad_width)); attr->grad_width));
PADDLE_ENFORCE_LE(attr->selected_rows_size, PADDLE_ENFORCE_LE(attr->selected_rows_size,
attr->grad_height, attr->grad_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be " "The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. " "equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.", "But selected_rows_size is %d and grad_height is %d.",
...@@ -567,7 +566,7 @@ void Sgd(const T* lr, ...@@ -567,7 +566,7 @@ void Sgd(const T* lr,
auto h_idx = rows[i]; auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, PADDLE_ENFORCE_LT(h_idx,
attr->param_height, attr->param_height,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The rows of Sgd should be " "The rows of Sgd should be "
"less than the attribute. But %dth of rows " "less than the attribute. But %dth of rows "
"is %d and grad_width is %d.", "is %d and grad_width is %d.",
...@@ -577,11 +576,11 @@ void Sgd(const T* lr, ...@@ -577,11 +576,11 @@ void Sgd(const T* lr,
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
h_idx, h_idx,
0, 0,
platform::errors::InvalidArgument("The rows of Sgd should be " phi::errors::InvalidArgument("The rows of Sgd should be "
"larger than 0. But %dth of rows " "larger than 0. But %dth of rows "
"is %d.", "is %d.",
i, i,
h_idx)); h_idx));
for (int64_t j = 0; j < attr->grad_width; ++j) { for (int64_t j = 0; j < attr->grad_width; ++j) {
out[h_idx * attr->grad_width + j] = out[h_idx * attr->grad_width + j] =
param[h_idx * attr->grad_width + j] - param[h_idx * attr->grad_width + j] -
...@@ -698,5 +697,4 @@ DECLARE_REFER_KERNEL(VBroadcast); ...@@ -698,5 +697,4 @@ DECLARE_REFER_KERNEL(VBroadcast);
} // namespace refer } // namespace refer
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -19,13 +19,12 @@ ...@@ -19,13 +19,12 @@
#include <type_traits> #include <type_traits>
#include <utility> // for std::move #include <utility> // for std::move
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/macros.h" #include "paddle/phi/core/macros.h"
#include "paddle/phi/kernels/funcs/jit/kernel_base.h"
#include "paddle/phi/kernels/funcs/jit/kernel_pool.h"
namespace paddle { namespace phi {
namespace operators {
namespace jit { namespace jit {
// make_unique is supported since c++14 // make_unique is supported since c++14
...@@ -84,23 +83,22 @@ class JitKernelRegistrar { ...@@ -84,23 +83,22 @@ class JitKernelRegistrar {
msg) msg)
// Refer always on CPUPlace // Refer always on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \ #define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \ __reg_jitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \ "REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::operators::jit::JitKernelRegistrar< \ static ::phi::jit::JitKernelRegistrar<::phi::jit::ReferKernelPool, \
::paddle::operators::jit::ReferKernelPool, \ ::phi::CPUPlace, \
::paddle::platform::CPUPlace, \ __VA_ARGS__> \
__VA_ARGS__> \ __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \ ::phi::jit::KernelType::kernel_type); \
::paddle::operators::jit::KernelType::kernel_type); \ int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \ __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \ return 0; \
return 0; \
} }
// kernel_type: should be in paddle::operators::jit::KernelType // kernel_type: should be in phi::jit::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform // place_type: should be one of CPUPlace and GPUPlace in phi
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \ #define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \ __reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
...@@ -108,12 +106,11 @@ class JitKernelRegistrar { ...@@ -108,12 +106,11 @@ class JitKernelRegistrar {
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \ static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \ static ::phi::jit::JitKernelRegistrar<::phi::jit::KernelPool, \
::paddle::operators::jit::KernelPool, \ ::phi::place_type, \
::paddle::platform::place_type, \ __VA_ARGS__> \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \ __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jit::KernelType::kernel_type); \ ::phi::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \ int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \ __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \ .Touch(); \
...@@ -126,22 +123,21 @@ class JitKernelRegistrar { ...@@ -126,22 +123,21 @@ class JitKernelRegistrar {
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \ #define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \ #define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \ __reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \ "REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \ static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \ static ::phi::jit::JitKernelRegistrar<::phi::jit::JitCodeCreatorPool, \
::paddle::operators::jit::JitCodeCreatorPool, \ ::phi::CPUPlace, \
::paddle::platform::CPUPlace, \ __VA_ARGS__> \
__VA_ARGS__> \ __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \ ::phi::jit::KernelType::kernel_type); \
::paddle::operators::jit::KernelType::kernel_type); \ int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \ __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \ return 0; \
return 0; \
} }
#define USE_JITKERNEL_GEN(kernel_type) \ #define USE_JITKERNEL_GEN(kernel_type) \
...@@ -174,5 +170,4 @@ class JitKernelRegistrar { ...@@ -174,5 +170,4 @@ class JitKernelRegistrar {
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace) USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
} // namespace jit } // namespace jit
} // namespace operators } // namespace phi
} // namespace paddle
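For orientation, the registration macros above are consumed at file scope in the implementation files. A minimal, hypothetical sketch under the post-move paths shown in this commit; the kernel chosen here, VAdd, is illustrative only, and only the macro shape is taken from the diff.

// Hypothetical sketch, not part of this commit: registering the portable
// reference implementation of VAdd for float and double with the relocated
// phi::jit registry. REGISTER_JITKERNEL_REFER must appear at global scope.
#include "paddle/phi/kernels/funcs/jit/refer/refer.h"
#include "paddle/phi/kernels/funcs/jit/registry.h"

namespace refer = phi::jit::refer;

REGISTER_JITKERNEL_REFER(kVAdd,
                         refer::VAddKernel<float>,
                         refer::VAddKernel<double>);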
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -18,9 +18,10 @@ limitations under the License. */ ...@@ -18,9 +18,10 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/jit/kernels.h"
DEFINE_double(acc, 1e-5, "Test accuracy threshold."); DEFINE_double(acc, 1e-5, "Test accuracy threshold.");
...@@ -62,8 +63,8 @@ std::vector<int> TestSizes() { ...@@ -62,8 +63,8 @@ std::vector<int> TestSizes() {
return s; return s;
} }
namespace jit = paddle::operators::jit; namespace jit = phi::jit;
using CPUPlace = paddle::platform::CPUPlace; using CPUPlace = phi::CPUPlace;
template <typename KernelTuple, template <typename KernelTuple,
typename PlaceType, typename PlaceType,
...@@ -1128,7 +1129,7 @@ void TestKernelSgd() { ...@@ -1128,7 +1129,7 @@ void TestKernelSgd() {
const int64_t upper) -> std::vector<int64_t> { const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower),
n - 1, n - 1,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The range of Sgd (upper - lower) should be lower " "The range of Sgd (upper - lower) should be lower "
"than n-1 (Sgd size -1). But the upper - lower is %d " "than n-1 (Sgd size -1). But the upper - lower is %d "
"and n-1 is %d.", "and n-1 is %d.",
...@@ -1137,7 +1138,7 @@ void TestKernelSgd() { ...@@ -1137,7 +1138,7 @@ void TestKernelSgd() {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
n, n,
0, 0,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The Sgd size should be larger than 0. But the n is %d.", n)); "The Sgd size should be larger than 0. But the n is %d.", n));
std::vector<int64_t> all, out; std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
......
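Taken together, the move changes include paths and namespaces (paddle::operators::jit becomes phi::jit, paddle::platform::CPUPlace becomes phi::CPUPlace) but not the lookup pattern itself. A minimal usage sketch, assuming VAddTuple and the KernelFuncs cache keep their pre-move interfaces:

// Hedged sketch: looks up (and caches) the best available VAdd kernel for a
// given problem size, preferring JIT-generated code where supported and
// falling back to an MKL or reference implementation, via the relocated headers.
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/jit/kernels.h"

void AddVectors(const float* x, const float* y, float* z, int n) {
  auto vadd =
      phi::jit::KernelFuncs<phi::jit::VAddTuple<float>, phi::CPUPlace>::Cache()
          .At(n);
  vadd(x, y, z, n);
}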