Commit 1aaec571 authored by: T tensor-tang

fix enum style

test=develop
Parent facfecbd
@@ -82,7 +82,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
 Tensor track;
 int* track_value =
 track.mutable_data<int>(emission_dims, platform::CPUPlace());
-auto ker = jit::Get<jit::crfdecoding, jit::CRFDecodingTuples<T>,
+auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
 platform::CPUPlace>(tag_num);
 ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
 T max_score = -std::numeric_limits<T>::max();
...
@@ -108,7 +108,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 constexpr int simd_width = 16;
 int C = c / simd_width;
-auto multiply = jit::Get<jit::nchw16cmulnc, jit::NCHW16CMulNCTuples<T>,
+auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
 platform::CPUPlace>(0);
 #pragma omp parallel for collapse(2)
 for (int ni = 0; ni < n; ni++) {
...
@@ -182,29 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
 const int total_T = x_dims[0]; \
 const int D3 = wh_dims[1]
 #define INIT_OTHER_DEFINES \
 auto* h0 = ctx.Input<Tensor>("H0"); \
 auto* wx = ctx.Input<Tensor>("WeightX"); \
 auto* bias = ctx.Input<Tensor>("Bias"); \
 auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
 bool is_reverse = ctx.Attr<bool>("is_reverse"); \
 const int M = x_dims[1]; \
 const int D = wh_dims[0]; \
 const int D2 = D * 2; \
 const jit::gru_attr_t attr( \
 D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
 jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
 jit::gru_t one_step; \
 auto ComputeH1 = \
-jit::Get<jit::gruh1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
+jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
 auto ComputeHtPart1 = \
-jit::Get<jit::gruhtpart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
+jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
 auto ComputeHtPart2 = \
-jit::Get<jit::gruhtpart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
+jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
 const T* x_data = x->data<T>(); \
 const T* wx_data = wx->data<T>(); \
 const T* wh_data = wh->data<T>(); \
 auto place = ctx.GetPlace(); \
 T* xx_data = xx->mutable_data<T>(place)
 void SeqCompute(const framework::ExecutionContext& ctx) const {
...
@@ -235,32 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
 const int D = wh_dims[0]; \
 const int D4 = wh_dims[1]
 #define INIT_OTHER_DEFINES \
 const T* x_data = x->data<T>(); \
 const T* wx_data = wx->data<T>(); \
 const T* wh_data = wh->data<T>(); \
 /* diagonal weight*/ \
 const T* wp_data = bias->data<T>() + D4; \
 /* for peephole only*/ \
 T* checked_cell_data = nullptr; \
 auto place = ctx.GetPlace(); \
 if (use_peepholes) { \
 /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
 auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
 checked_cell_data = checked_cell->mutable_data<T>(place); \
 } \
 const jit::lstm_attr_t attr( \
 D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
 jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
 jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
 use_peepholes); \
 jit::lstm_t one_step; \
 one_step.wp = wp_data; \
 one_step.checked = checked_cell_data; \
 auto ComputeC1H1 = \
-jit::Get<jit::lstmc1h1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
+jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
 auto ComputeCtHt = \
-jit::Get<jit::lstmctht, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
+jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out) \
...
@@ -146,7 +146,7 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
 void BenchLSTMKernel() {
 for (bool use_peephole : {true, false}) {
 for (int d : TestSizes()) {
-const jit::lstm_attr_t attr(d, jit::vsigmoid, jit::vtanh, jit::vtanh,
+const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
 use_peephole);
 std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
 RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
...
@@ -175,7 +175,7 @@ void BenchLSTMKernel() {
 template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
 void BenchGRUKernel() {
 for (int d : TestSizes()) {
-const jit::gru_attr_t attr(d, jit::vsigmoid, jit::vtanh);
+const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
 std::vector<T> x(3 * d), ht_1(d), ht(d);
 RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
 RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
...
@@ -204,28 +204,28 @@ int main(int argc, char* argv[]) {
 using T = float;
 using PlaceType = paddle::platform::CPUPlace;
 // xyzn
-BenchXYZNKernel<jit::vmul, T, PlaceType>();
-BenchXYZNKernel<jit::vadd, T, PlaceType>();
-BenchXYZNKernel<jit::vaddrelu, T, PlaceType>();
-BenchXYZNKernel<jit::vsub, T, PlaceType>();
+BenchXYZNKernel<jit::kVMul, T, PlaceType>();
+BenchXYZNKernel<jit::kVAdd, T, PlaceType>();
+BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>();
+BenchXYZNKernel<jit::kVSub, T, PlaceType>();
 // axyn
-BenchAXYNKernel<jit::vscal, T, PlaceType>();
-BenchAXYNKernel<jit::vaddbias, T, PlaceType>();
+BenchAXYNKernel<jit::kVScal, T, PlaceType>();
+BenchAXYNKernel<jit::kVAddBias, T, PlaceType>();
 // xyn
-BenchXYNKernel<jit::vrelu, T, PlaceType>();
-BenchXYNKernel<jit::videntity, T, PlaceType>();
-BenchXYNKernel<jit::vexp, T, PlaceType>();
-BenchXYNKernel<jit::vsigmoid, T, PlaceType>();
-BenchXYNKernel<jit::vtanh, T, PlaceType>();
+BenchXYNKernel<jit::kVRelu, T, PlaceType>();
+BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
+BenchXYNKernel<jit::kVExp, T, PlaceType>();
+BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
+BenchXYNKernel<jit::kVTanh, T, PlaceType>();
 // lstm and peephole
-BenchLSTMKernel<jit::lstmctht, T, PlaceType>();
-BenchLSTMKernel<jit::lstmc1h1, T, PlaceType>();
+BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>();
+BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>();
 // gru functions
-BenchGRUKernel<jit::gruh1, T, PlaceType>();
-BenchGRUKernel<jit::gruhtpart1, T, PlaceType>();
-BenchGRUKernel<jit::gruhtpart2, T, PlaceType>();
+BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
+BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
+BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
 }
@@ -9,20 +9,20 @@ function(USE_JITKERNEL_GEN TARGET)
 endfunction()
 # use gen jitcode kernel by name
-USE_JITKERNEL_GEN(vmul)
-USE_JITKERNEL_GEN(vadd)
-#USE_JITKERNEL_GEN(vsub) # TODO(TJ): enable me
-USE_JITKERNEL_GEN(vaddrelu)
-USE_JITKERNEL_GEN(vscal)
-USE_JITKERNEL_GEN(vaddbias)
-USE_JITKERNEL_GEN(vrelu)
-USE_JITKERNEL_GEN(videntity)
-USE_JITKERNEL_GEN(vexp)
-USE_JITKERNEL_GEN(vsigmoid)
-USE_JITKERNEL_GEN(vtanh)
-USE_JITKERNEL_GEN(lstmctht)
-USE_JITKERNEL_GEN(lstmc1h1)
-USE_JITKERNEL_GEN(gruh1)
-USE_JITKERNEL_GEN(gruhtpart1)
-USE_JITKERNEL_GEN(gruhtpart2)
-USE_JITKERNEL_GEN(nchw16cmulnc)
+USE_JITKERNEL_GEN(kVMul)
+USE_JITKERNEL_GEN(kVAdd)
+#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
+USE_JITKERNEL_GEN(kVAddRelu)
+USE_JITKERNEL_GEN(kVScal)
+USE_JITKERNEL_GEN(kVAddBias)
+USE_JITKERNEL_GEN(kVRelu)
+USE_JITKERNEL_GEN(kVIdentity)
+USE_JITKERNEL_GEN(kVExp)
+USE_JITKERNEL_GEN(kVSigmoid)
+USE_JITKERNEL_GEN(kVTanh)
+USE_JITKERNEL_GEN(kLSTMCtHt)
+USE_JITKERNEL_GEN(kLSTMC1H1)
+USE_JITKERNEL_GEN(kGRUH1)
+USE_JITKERNEL_GEN(kGRUHtPart1)
+USE_JITKERNEL_GEN(kGRUHtPart2)
+USE_JITKERNEL_GEN(kNCHW16CMulNC)
@@ -128,8 +128,8 @@ size_t VTanhCreator::CodeSize(const int& d) const {
 namespace gen = paddle::operators::jit::gen;
-REGISTER_JITKERNEL_GEN(vrelu, gen::VReluCreator);
-REGISTER_JITKERNEL_GEN(videntity, gen::VIdentityCreator);
-REGISTER_JITKERNEL_GEN(vexp, gen::VExpCreator);
-REGISTER_JITKERNEL_GEN(vsigmoid, gen::VSigmoidCreator);
-REGISTER_JITKERNEL_GEN(vtanh, gen::VTanhCreator);
+REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
+REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
+REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
+REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
+REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator);
@@ -176,11 +176,11 @@ DECLARE_BLAS_CREATOR(VAddBias);
 namespace gen = paddle::operators::jit::gen;
-REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator);
-REGISTER_JITKERNEL_GEN(vadd, gen::VAddCreator);
+REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
+REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
 // TODO(TJ): enable sub
-// REGISTER_JITKERNEL_GEN(vsub, gen::VSubCreator);
-REGISTER_JITKERNEL_GEN(vaddrelu, gen::VAddReluCreator);
-REGISTER_JITKERNEL_GEN(vscal, gen::VScalCreator);
-REGISTER_JITKERNEL_GEN(vaddbias, gen::VAddBiasCreator);
-REGISTER_JITKERNEL_GEN(nchw16cmulnc, gen::NCHW16CMulNCCreator);
+// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
+REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
+REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
+REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
+REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);
@@ -111,6 +111,6 @@ DECLARE_GRU_CREATOR(GRUHtPart2);
 namespace gen = paddle::operators::jit::gen;
-REGISTER_JITKERNEL_GEN(gruh1, gen::GRUH1Creator);
-REGISTER_JITKERNEL_GEN(gruhtpart1, gen::GRUHtPart1Creator);
-REGISTER_JITKERNEL_GEN(gruhtpart2, gen::GRUHtPart2Creator);
+REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
+REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
+REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator);
@@ -30,13 +30,13 @@ class GRUJitCode : public VActFunc {
 void* code_ptr = nullptr)
 : VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
 auto typeExchange = [](KernelType type) -> gen::operand_type {
-if (type == KernelType::vsigmoid) {
+if (type == KernelType::kVSigmoid) {
 return operand_type::SIGMOID;
-} else if (type == KernelType::vrelu) {
+} else if (type == KernelType::kVRelu) {
 return operand_type::RELU;
-} else if (type == KernelType::vtanh) {
+} else if (type == KernelType::kVTanh) {
 return operand_type::TANH;
-} else if (type == KernelType::videntity) {
+} else if (type == KernelType::kVIdentity) {
 return operand_type::IDENTITY;
 } else {
 LOG(FATAL) << "Do not support this jit::KernelType: " << type;
...
@@ -138,5 +138,5 @@ DECLARE_LSTM_CREATOR(LSTMC1H1);
 namespace gen = paddle::operators::jit::gen;
-REGISTER_JITKERNEL_GEN(lstmctht, gen::LSTMCtHtCreator);
-REGISTER_JITKERNEL_GEN(lstmc1h1, gen::LSTMC1H1Creator);
+REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
+REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);
@@ -33,13 +33,13 @@ class LSTMJitCode : public VActFunc {
 compute_c1h1_(compute_c1h1),
 use_peephole_(attr.use_peephole) {
 auto typeExchange = [](KernelType type) -> gen::operand_type {
-if (type == KernelType::vsigmoid) {
+if (type == KernelType::kVSigmoid) {
 return operand_type::SIGMOID;
-} else if (type == KernelType::vrelu) {
+} else if (type == KernelType::kVRelu) {
 return operand_type::RELU;
-} else if (type == KernelType::vtanh) {
+} else if (type == KernelType::kVTanh) {
 return operand_type::TANH;
-} else if (type == KernelType::videntity) {
+} else if (type == KernelType::kVIdentity) {
 return operand_type::IDENTITY;
 } else {
 LOG(FATAL) << "Do not support this jit::KernelType: " << type;
...
@@ -26,25 +26,25 @@ namespace jit {
 const char* to_string(KernelType kt) {
 switch (kt) {
-ONE_CASE(vmul);
-ONE_CASE(vadd);
-ONE_CASE(vaddrelu);
-ONE_CASE(vsub);
-ONE_CASE(vscal);
-ONE_CASE(vaddbias);
-ONE_CASE(vrelu);
-ONE_CASE(videntity);
-ONE_CASE(vexp);
-ONE_CASE(vsigmoid);
-ONE_CASE(vtanh);
-ONE_CASE(lstmctht);
-ONE_CASE(lstmc1h1);
-ONE_CASE(gruh1);
-ONE_CASE(gruhtpart1);
-ONE_CASE(gruhtpart2);
-ONE_CASE(crfdecoding);
-ONE_CASE(layernorm);
-ONE_CASE(nchw16cmulnc);
+ONE_CASE(kVMul);
+ONE_CASE(kVAdd);
+ONE_CASE(kVAddRelu);
+ONE_CASE(kVSub);
+ONE_CASE(kVScal);
+ONE_CASE(kVAddBias);
+ONE_CASE(kVRelu);
+ONE_CASE(kVIdentity);
+ONE_CASE(kVExp);
+ONE_CASE(kVSigmoid);
+ONE_CASE(kVTanh);
+ONE_CASE(kLSTMCtHt);
+ONE_CASE(kLSTMC1H1);
+ONE_CASE(kGRUH1);
+ONE_CASE(kGRUHtPart1);
+ONE_CASE(kGRUHtPart2);
+ONE_CASE(kCRFDecoding);
+ONE_CASE(kLayerNorm);
+ONE_CASE(kNCHW16CMulNC);
 default:
 PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
 return "NOT JITKernel";
@@ -57,19 +57,18 @@ KernelType to_kerneltype(const std::string& act) {
 std::string lower = act;
 std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
 if (lower == "relu" || lower == "vrelu") {
-return vrelu;
+return kVRelu;
 } else if (lower == "identity" || lower == "videntity" || lower == "") {
-return videntity;
+return kVIdentity;
 } else if (lower == "exp" || lower == "vexp") {
-return vexp;
+return kVExp;
 } else if (lower == "sigmoid" || lower == "vsigmoid") {
-return vsigmoid;
+return kVSigmoid;
 } else if (lower == "tanh" || lower == "vtanh") {
-return vtanh;
+return kVTanh;
 }
 PADDLE_THROW("Not support type: %s, or forget to add this case", act);
-return non_kernel;
+return kNone;
 }
 } // namespace jit
...
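As a quick illustration of the helper in the hunk above, here is a minimal usage sketch of to_kerneltype() with the renamed constants. It is not part of this commit, and the include path for the jit helper header is an assumption.

// Minimal sketch (not part of this commit): mapping activation attribute
// strings to the renamed enum values via to_kerneltype().
// Assumed include, e.g. "paddle/fluid/operators/jit/helper.h" (path assumed).
namespace jit = paddle::operators::jit;

void ToKernelTypeDemo() {
  jit::KernelType gate = jit::to_kerneltype("sigmoid");  // yields jit::kVSigmoid
  jit::KernelType cand = jit::to_kerneltype("vtanh");    // yields jit::kVTanh
  jit::KernelType iden = jit::to_kerneltype("");         // empty string maps to jit::kVIdentity
  (void)gate; (void)cand; (void)iden;
}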
@@ -21,26 +21,26 @@ namespace operators {
 namespace jit {
 typedef enum {
-non_kernel = 0,
-vmul = 1,
-vadd = 2,
-vaddrelu,
-vsub,
-vscal,
-vaddbias,
-vrelu,
-videntity,
-vexp,
-vsigmoid,
-vtanh,
-lstmctht,
-lstmc1h1,
-gruh1,
-gruhtpart1,
-gruhtpart2,
-crfdecoding,
-layernorm,
-nchw16cmulnc,
+kNone = 0,
+kVMul = 1,
+kVAdd = 2,
+kVAddRelu,
+kVSub,
+kVScal,
+kVAddBias,
+kVRelu,
+kVIdentity,
+kVExp,
+kVSigmoid,
+kVTanh,
+kLSTMCtHt,
+kLSTMC1H1,
+kGRUH1,
+kGRUHtPart1,
+kGRUHtPart2,
+kCRFDecoding,
+kLayerNorm,
+kNCHW16CMulNC,
 } KernelType;
 template <typename T>
...
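For context on how these renamed KernelType values are consumed, below is a minimal sketch of the lookup-and-call pattern that appears throughout this diff (jit::Get keyed by enum value, tuple type, and place). It is not part of the commit; the headers and their include paths are assumed.

#include <vector>

// Assumed includes for the jit API and CPUPlace (paths assumed, not shown in this diff).
namespace jit = paddle::operators::jit;

void VAddDemo(int n) {
  std::vector<float> x(n, 1.f), y(n, 2.f), z(n);
  // Same pattern as the call sites in this diff: key, tuple type, place.
  auto vadd = jit::Get<jit::kVAdd, jit::XYZNTuples<float>,
                       paddle::platform::CPUPlace>(n);
  vadd(x.data(), y.data(), z.data(), n);  // elementwise z = x + y
}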
@@ -5,5 +5,5 @@ cc_library(jit_kernel_intrinsic SRCS ${jit_kernel_cc_intrinsic} DEPS jit_kernel_
 set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
 # use mkl kernels by name and type
-USE_JITKERNEL_MORE(crfdecoding, intrinsic)
-USE_JITKERNEL_MORE(layernorm, intrinsic)
+USE_JITKERNEL_MORE(kCRFDecoding, intrinsic)
+USE_JITKERNEL_MORE(kLayerNorm, intrinsic)
@@ -178,4 +178,4 @@ bool CRFDecodingKernel::UseMe(const int& d) const {
 namespace intrinsic = paddle::operators::jit::more::intrinsic;
-REGISTER_JITKERNEL_MORE(crfdecoding, intrinsic, intrinsic::CRFDecodingKernel);
+REGISTER_JITKERNEL_MORE(kCRFDecoding, intrinsic, intrinsic::CRFDecodingKernel);
@@ -165,4 +165,4 @@ bool LayerNormKernel::UseMe(const int& d) const {
 namespace intrinsic = paddle::operators::jit::more::intrinsic;
-REGISTER_JITKERNEL_MORE(layernorm, intrinsic, intrinsic::LayerNormKernel);
+REGISTER_JITKERNEL_MORE(kLayerNorm, intrinsic, intrinsic::LayerNormKernel);
@@ -5,10 +5,10 @@ cc_library(jit_kernel_mix SRCS ${jit_kernel_mix_cc} DEPS jit_kernel_base)
 set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE)
-USE_JITKERNEL_MORE(vsigmoid, mix)
-USE_JITKERNEL_MORE(vtanh, mix)
-USE_JITKERNEL_MORE(lstmctht, mix)
-USE_JITKERNEL_MORE(lstmc1h1, mix)
-USE_JITKERNEL_MORE(gruh1, mix)
-USE_JITKERNEL_MORE(gruhtpart1, mix)
-USE_JITKERNEL_MORE(gruhtpart2, mix)
+USE_JITKERNEL_MORE(kVSigmoid, mix)
+USE_JITKERNEL_MORE(kVTanh, mix)
+USE_JITKERNEL_MORE(kLSTMCtHt, mix)
+USE_JITKERNEL_MORE(kLSTMC1H1, mix)
+USE_JITKERNEL_MORE(kGRUH1, mix)
+USE_JITKERNEL_MORE(kGRUHtPart1, mix)
+USE_JITKERNEL_MORE(kGRUHtPart2, mix)
@@ -30,7 +30,7 @@ void VSigmoid(const T* x, T* y, int n) {
 y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
 y[i] = static_cast<T>(0) - y[i];
 }
-auto compute = Get<KernelType::vexp, XYNTuples<T>, platform::CPUPlace>(n);
+auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
 compute(y, y, n);
 for (int i = 0; i < n; ++i) {
 y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
@@ -39,9 +39,9 @@ void VSigmoid(const T* x, T* y, int n) {
 void VTanh(const T* x, T* y, int n) {
 const T a = 2, b = -1;
-auto compute_scal = Get<vscal, AXYNTuples<T>, platform::CPUPlace>(n);
-auto compute_addbias = Get<vaddbias, AXYNTuples<T>, platform::CPUPlace>(n);
-auto compute_sigmoid = Get<vsigmoid, XYNTuples<T>, platform::CPUPlace>(n);
+auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
+auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
+auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
 compute_scal(&a, x, y, n);
 compute_sigmoid(y, y, n);
 compute_scal(&a, y, y, n);
@@ -49,14 +49,14 @@ void VTanh(const T* x, T* y, int n) {
 }
 void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
-if (type == vsigmoid) {
-return Get<vsigmoid, XYNTuples<T>, platform::CPUPlace>(d);
-} else if (type == vrelu) {
-return Get<vrelu, XYNTuples<T>, platform::CPUPlace>(d);
-} else if (type == vtanh) {
-return Get<vtanh, XYNTuples<T>, platform::CPUPlace>(d);
-} else if (type == videntity) {
-return Get<videntity, XYNTuples<T>, platform::CPUPlace>(d);
+if (type == kVSigmoid) {
+return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
+} else if (type == kVRelu) {
+return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
+} else if (type == kVTanh) {
+return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
+} else if (type == kVIdentity) {
+return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
 }
 PADDLE_THROW("Not support type: %s", type);
 return nullptr;
@@ -72,9 +72,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
 const int d = attr->d;
 const int d2 = d * 2;
 const int d3 = d * 3;
-auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
-auto vadd_d = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d);
-auto vadd_d2 = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d2);
+auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
+auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
+auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
 auto act_gate_d = getActFunc(attr->act_gate, d);
 auto act_gate_d2 = getActFunc(attr->act_gate, d2);
 auto act_gate_d3 = getActFunc(attr->act_gate, d3);
@@ -114,8 +114,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
 int d = attr->d;
 int d2 = d * 2;
 int d3 = d * 3;
-auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
-auto vadd_d = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d);
+auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
+auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
 auto act_gate_d = getActFunc(attr->act_gate, d);
 auto act_cand_d = getActFunc(attr->act_cand, d);
 auto act_cell_d = getActFunc(attr->act_cell, d);
@@ -143,7 +143,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
 int d2 = d * 2;
 auto act_gate = getActFunc(attr->act_gate, d);
 auto act_cand = getActFunc(attr->act_cand, d);
-auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
+auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
 act_gate(gates, gates, d);
 act_cand(gates + d2, gates + d2, d);
 vmul_d(gates, gates + d2, ht, d);
@@ -156,7 +156,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
 T* ht = reinterpret_cast<T*>(step->ht);
 const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
 auto act_gate = getActFunc(attr->act_gate, attr->d);
-auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
+auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
 act_gate(gates + attr->d, gates + attr->d, attr->d);
 vmul_d(ht_1, gates + attr->d, ht, attr->d);
 }
@@ -205,12 +205,12 @@ namespace mix = paddle::operators::jit::more::mix;
 #define REGISTER_MORE_KERNEL(key, func) \
 REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
-REGISTER_MORE_KERNEL(vsigmoid, VSigmoid);
-REGISTER_MORE_KERNEL(vtanh, VTanh);
-REGISTER_MORE_KERNEL(lstmctht, LSTMCtHt);
-REGISTER_MORE_KERNEL(lstmc1h1, LSTMC1H1);
-REGISTER_MORE_KERNEL(gruh1, GRUH1);
-REGISTER_MORE_KERNEL(gruhtpart1, GRUHtPart1);
-REGISTER_MORE_KERNEL(gruhtpart2, GRUHtPart2);
+REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
+REGISTER_MORE_KERNEL(kVTanh, VTanh);
+REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
+REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
+REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
+REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
+REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
 #undef REGISTER_MORE_KERNEL
@@ -3,9 +3,9 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
 set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
 # use mkl kernels by name and type
-USE_JITKERNEL_MORE(vmul, mkl)
-USE_JITKERNEL_MORE(vadd, mkl)
-USE_JITKERNEL_MORE(vscal, mkl)
-USE_JITKERNEL_MORE(vexp, mkl)
-USE_JITKERNEL_MORE(vsigmoid, mkl)
-USE_JITKERNEL_MORE(vtanh, mkl)
+USE_JITKERNEL_MORE(kVMul, mkl)
+USE_JITKERNEL_MORE(kVAdd, mkl)
+USE_JITKERNEL_MORE(kVScal, mkl)
+USE_JITKERNEL_MORE(kVExp, mkl)
+USE_JITKERNEL_MORE(kVSigmoid, mkl)
+USE_JITKERNEL_MORE(kVTanh, mkl)
@@ -129,11 +129,11 @@ namespace mkl = paddle::operators::jit::more::mkl;
 REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
 mkl::func##Kernel<double>)
-REGISTER_MKL_KERNEL(vmul, VMul);
-REGISTER_MKL_KERNEL(vadd, VAdd);
-REGISTER_MKL_KERNEL(vscal, VScal);
-REGISTER_MKL_KERNEL(vexp, VExp);
-REGISTER_MKL_KERNEL(vsigmoid, VSigmoid);
-REGISTER_MKL_KERNEL(vtanh, VTanh);
+REGISTER_MKL_KERNEL(kVMul, VMul);
+REGISTER_MKL_KERNEL(kVAdd, VAdd);
+REGISTER_MKL_KERNEL(kVScal, VScal);
+REGISTER_MKL_KERNEL(kVExp, VExp);
+REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
+REGISTER_MKL_KERNEL(kVTanh, VTanh);
 #undef REGISTER_MKL_KERNEL
@@ -7,22 +7,22 @@ function(USE_JITKERNEL_REFER TARGET)
 endfunction()
 # use refer kernel by name
-USE_JITKERNEL_REFER(vmul)
-USE_JITKERNEL_REFER(vadd)
-USE_JITKERNEL_REFER(vaddrelu)
-USE_JITKERNEL_REFER(vsub)
-USE_JITKERNEL_REFER(vscal)
-USE_JITKERNEL_REFER(vaddbias)
-USE_JITKERNEL_REFER(vrelu)
-USE_JITKERNEL_REFER(videntity)
-USE_JITKERNEL_REFER(vexp)
-USE_JITKERNEL_REFER(vsigmoid)
-USE_JITKERNEL_REFER(vtanh)
-USE_JITKERNEL_REFER(lstmctht)
-USE_JITKERNEL_REFER(lstmc1h1)
-USE_JITKERNEL_REFER(gruh1)
-USE_JITKERNEL_REFER(gruhtpart1)
-USE_JITKERNEL_REFER(gruhtpart2)
-USE_JITKERNEL_REFER(crfdecoding)
-USE_JITKERNEL_REFER(layernorm)
-USE_JITKERNEL_REFER(nchw16cmulnc)
+USE_JITKERNEL_REFER(kVMul)
+USE_JITKERNEL_REFER(kVAdd)
+USE_JITKERNEL_REFER(kVAddRelu)
+USE_JITKERNEL_REFER(kVSub)
+USE_JITKERNEL_REFER(kVScal)
+USE_JITKERNEL_REFER(kVAddBias)
+USE_JITKERNEL_REFER(kVRelu)
+USE_JITKERNEL_REFER(kVIdentity)
+USE_JITKERNEL_REFER(kVExp)
+USE_JITKERNEL_REFER(kVSigmoid)
+USE_JITKERNEL_REFER(kVTanh)
+USE_JITKERNEL_REFER(kLSTMCtHt)
+USE_JITKERNEL_REFER(kLSTMC1H1)
+USE_JITKERNEL_REFER(kGRUH1)
+USE_JITKERNEL_REFER(kGRUHtPart1)
+USE_JITKERNEL_REFER(kGRUHtPart2)
+USE_JITKERNEL_REFER(kCRFDecoding)
+USE_JITKERNEL_REFER(kLayerNorm)
+USE_JITKERNEL_REFER(kNCHW16CMulNC)
@@ -21,30 +21,30 @@ namespace refer = paddle::operators::jit::refer;
 REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
 refer::func##Kernel<double>)
-REGISTER_REFER_KERNEL(vmul, VMul);
-REGISTER_REFER_KERNEL(vadd, VAdd);
-REGISTER_REFER_KERNEL(vaddrelu, VAddRelu);
-REGISTER_REFER_KERNEL(vsub, VSub);
-REGISTER_REFER_KERNEL(vscal, VScal);
-REGISTER_REFER_KERNEL(vaddbias, VAddBias);
-REGISTER_REFER_KERNEL(vrelu, VRelu);
-REGISTER_REFER_KERNEL(videntity, VIdentity);
-REGISTER_REFER_KERNEL(vexp, VExp);
-REGISTER_REFER_KERNEL(vsigmoid, VSigmoid);
-REGISTER_REFER_KERNEL(vtanh, VTanh);
-REGISTER_REFER_KERNEL(lstmctht, LSTMCtHt);
-REGISTER_REFER_KERNEL(lstmc1h1, LSTMC1H1);
-REGISTER_REFER_KERNEL(gruh1, GRUH1);
-REGISTER_REFER_KERNEL(gruhtpart1, GRUHtPart1);
-REGISTER_REFER_KERNEL(gruhtpart2, GRUHtPart2);
-REGISTER_REFER_KERNEL(crfdecoding, CRFDecoding);
-REGISTER_REFER_KERNEL(layernorm, LayerNorm);
-REGISTER_REFER_KERNEL(nchw16cmulnc, NCHW16CMulNC);
+REGISTER_REFER_KERNEL(kVMul, VMul);
+REGISTER_REFER_KERNEL(kVAdd, VAdd);
+REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu);
+REGISTER_REFER_KERNEL(kVSub, VSub);
+REGISTER_REFER_KERNEL(kVScal, VScal);
+REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
+REGISTER_REFER_KERNEL(kVRelu, VRelu);
+REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
+REGISTER_REFER_KERNEL(kVExp, VExp);
+REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
+REGISTER_REFER_KERNEL(kVTanh, VTanh);
+REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt);
+REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1);
+REGISTER_REFER_KERNEL(kGRUH1, GRUH1);
+REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1);
+REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2);
+REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding);
+REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
+REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
 #undef REGISTER_REFER_KERNEL
@@ -115,13 +115,13 @@ void VTanh(const T* x, T* y, int n) {
 template <typename T>
 void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
-if (type == vsigmoid) {
+if (type == kVSigmoid) {
 return VSigmoid<T>;
-} else if (type == vrelu) {
+} else if (type == kVRelu) {
 return VRelu<T>;
-} else if (type == vtanh) {
+} else if (type == kVTanh) {
 return VTanh<T>;
-} else if (type == videntity) {
+} else if (type == kVIdentity) {
 return VIdentity<T>;
 }
 PADDLE_THROW("Not support type: %s", type);
...
@@ -469,111 +469,111 @@ void TestNCHW16CMulNCKernel() {
 }
 // XYZNTuple
-TEST(JITKernel, vmul) {
+TEST(JITKernel, kVMul) {
 namespace jit = paddle::operators::jit;
-TestXYZNKernel<jit::vmul, float, paddle::platform::CPUPlace>();
-TestXYZNKernel<jit::vmul, double, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVMul, float, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVMul, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, vadd) {
+TEST(JITKernel, kVAdd) {
 namespace jit = paddle::operators::jit;
-TestXYZNKernel<jit::vadd, float, paddle::platform::CPUPlace>();
-TestXYZNKernel<jit::vadd, double, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVAdd, float, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVAdd, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, vaddrelu) {
+TEST(JITKernel, kVAddRelu) {
 namespace jit = paddle::operators::jit;
-TestXYZNKernel<jit::vaddrelu, float, paddle::platform::CPUPlace>();
-TestXYZNKernel<jit::vaddrelu, double, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVAddRelu, float, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVAddRelu, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, vsub) {
+TEST(JITKernel, kVSub) {
 namespace jit = paddle::operators::jit;
-TestXYZNKernel<jit::vsub, float, paddle::platform::CPUPlace>();
-TestXYZNKernel<jit::vsub, double, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVSub, float, paddle::platform::CPUPlace>();
+TestXYZNKernel<jit::kVSub, double, paddle::platform::CPUPlace>();
 }
 // AXYNTuples
-TEST(JITKernel, vscal) {
+TEST(JITKernel, kVScal) {
 namespace jit = paddle::operators::jit;
-TestAXYNKernel<jit::vscal, float, paddle::platform::CPUPlace>();
-TestAXYNKernel<jit::vscal, double, paddle::platform::CPUPlace>();
+TestAXYNKernel<jit::kVScal, float, paddle::platform::CPUPlace>();
+TestAXYNKernel<jit::kVScal, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, vaddbias) {
+TEST(JITKernel, kVAddBias) {
 namespace jit = paddle::operators::jit;
-TestAXYNKernel<jit::vaddbias, float, paddle::platform::CPUPlace>();
-TestAXYNKernel<jit::vaddbias, double, paddle::platform::CPUPlace>();
+TestAXYNKernel<jit::kVAddBias, float, paddle::platform::CPUPlace>();
+TestAXYNKernel<jit::kVAddBias, double, paddle::platform::CPUPlace>();
 }
 // XYNTuples
-TEST(JITKernel, vrelu) {
+TEST(JITKernel, kVRelu) {
 namespace jit = paddle::operators::jit;
-TestXYNKernel<jit::vrelu, float, paddle::platform::CPUPlace>();
-TestXYNKernel<jit::vrelu, double, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVRelu, float, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVRelu, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, videntity) {
+TEST(JITKernel, kVIdentity) {
 namespace jit = paddle::operators::jit;
-TestXYNKernel<jit::videntity, float, paddle::platform::CPUPlace>();
-TestXYNKernel<jit::videntity, double, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVIdentity, float, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, vexp) {
+TEST(JITKernel, kVExp) {
 namespace jit = paddle::operators::jit;
-TestXYNKernel<jit::vexp, float, paddle::platform::CPUPlace>();
-TestXYNKernel<jit::vexp, double, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVExp, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, vsigmoid) {
+TEST(JITKernel, kVSigmoid) {
 namespace jit = paddle::operators::jit;
-TestXYNKernel<jit::vsigmoid, float, paddle::platform::CPUPlace>();
-TestXYNKernel<jit::vsigmoid, double, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVSigmoid, float, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVSigmoid, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, vtanh) {
+TEST(JITKernel, kVTanh) {
 namespace jit = paddle::operators::jit;
-TestXYNKernel<jit::vtanh, float, paddle::platform::CPUPlace>();
-TestXYNKernel<jit::vtanh, double, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVTanh, float, paddle::platform::CPUPlace>();
+TestXYNKernel<jit::kVTanh, double, paddle::platform::CPUPlace>();
 }
 // LSTM
-TEST(JITKernel, lstmctht) {
+TEST(JITKernel, kLSTMCtHt) {
 namespace jit = paddle::operators::jit;
-TestLSTMKernel<jit::lstmctht, float, paddle::platform::CPUPlace>();
-TestLSTMKernel<jit::lstmctht, double, paddle::platform::CPUPlace>();
+TestLSTMKernel<jit::kLSTMCtHt, float, paddle::platform::CPUPlace>();
+TestLSTMKernel<jit::kLSTMCtHt, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, lstmc1h1) {
+TEST(JITKernel, kLSTMC1H1) {
 namespace jit = paddle::operators::jit;
-TestLSTMKernel<jit::lstmc1h1, float, paddle::platform::CPUPlace>();
-TestLSTMKernel<jit::lstmc1h1, double, paddle::platform::CPUPlace>();
+TestLSTMKernel<jit::kLSTMC1H1, float, paddle::platform::CPUPlace>();
+TestLSTMKernel<jit::kLSTMC1H1, double, paddle::platform::CPUPlace>();
 }
 // GRU
-TEST(JITKernel, gruh1) {
+TEST(JITKernel, kGRUH1) {
 namespace jit = paddle::operators::jit;
-TestGRUKernel<jit::gruh1, float, paddle::platform::CPUPlace>();
-TestGRUKernel<jit::gruh1, double, paddle::platform::CPUPlace>();
+TestGRUKernel<jit::kGRUH1, float, paddle::platform::CPUPlace>();
+TestGRUKernel<jit::kGRUH1, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, gruhtpart1) {
+TEST(JITKernel, kGRUHtPart1) {
 namespace jit = paddle::operators::jit;
-TestGRUKernel<jit::gruhtpart1, float, paddle::platform::CPUPlace>();
-TestGRUKernel<jit::gruhtpart1, double, paddle::platform::CPUPlace>();
+TestGRUKernel<jit::kGRUHtPart1, float, paddle::platform::CPUPlace>();
+TestGRUKernel<jit::kGRUHtPart1, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, gruhtpart2) {
+TEST(JITKernel, kGRUHtPart2) {
 namespace jit = paddle::operators::jit;
-TestGRUKernel<jit::gruhtpart2, float, paddle::platform::CPUPlace>();
-TestGRUKernel<jit::gruhtpart2, double, paddle::platform::CPUPlace>();
+TestGRUKernel<jit::kGRUHtPart2, float, paddle::platform::CPUPlace>();
+TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
 }
-TEST(JITKernel, nchw16cmulnc) {
+TEST(JITKernel, kNCHW16CMulNC) {
 namespace jit = paddle::operators::jit;
-TestNCHW16CMulNCKernel<jit::nchw16cmulnc, float,
+TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
 paddle::platform::CPUPlace>();
-TestNCHW16CMulNCKernel<jit::nchw16cmulnc, double,
+TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double,
 paddle::platform::CPUPlace>();
 }
...
@@ -230,7 +230,7 @@ class LayerNormKernel : public framework::OpKernel<T> {
 PADDLE_ENFORCE_EQ(bias->numel(), right);
 auto ker =
-jit::Get<jit::layernorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
+jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
 right);
 ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
 scale->data<T>(), bias->data<T>(), static_cast<int>(left),
...
@@ -31,14 +31,14 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
 }
 if (relu) {
 auto compute =
-jit::Get<jit::vaddrelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
+jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
 for (int i = 0; i < M; i++) {
 T* dst = Y + i * N;
 compute(B, dst, dst, N);
 }
 } else {
 auto compute =
-jit::Get<jit::vadd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
+jit::Get<jit::kVAdd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
...