提交 4a93db92 编写于 作者: T tensor-tang

remove jit namespace

test=develop
上级 8cda28f3
...@@ -231,10 +231,10 @@ use lstm_x_t as input and compute as standard LSTM. ...@@ -231,10 +231,10 @@ use lstm_x_t as input and compute as standard LSTM.
template <typename T> template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) { inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
if (bias) { if (bias) {
math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y); math::vec_add_bias<T, platform::avx>(n, *bias, x, y);
math::vec_relu<T, platform::jit::avx>(n, y, y); math::vec_relu<T, platform::avx>(n, y, y);
} else { } else {
math::vec_relu<T, platform::jit::avx>(n, x, y); math::vec_relu<T, platform::avx>(n, x, y);
} }
} }
...@@ -245,7 +245,7 @@ inline void vec_softmax(const int n, const T* x, T* y) { ...@@ -245,7 +245,7 @@ inline void vec_softmax(const int n, const T* x, T* y) {
for (int i = 1; i < n; ++i) { for (int i = 1; i < n; ++i) {
scalar = scalar < x[i] ? x[i] : scalar; scalar = scalar < x[i] ? x[i] : scalar;
} }
math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y); // sub math::vec_add_bias<T, platform::avx>(n, -scalar, x, y); // sub
math::vec_exp<T>(n, y, y); // exp math::vec_exp<T>(n, y, y); // exp
// sum // sum
scalar = T(0); scalar = T(0);
...@@ -302,13 +302,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> { ...@@ -302,13 +302,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) { if (platform::MayIUse(platform::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor; math::VecActivations<T, platform::avx> act_functor;
act_gate = act_functor(act_gate_str); act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str); act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str); act_cand = act_functor(act_cand_str);
} else { } else {
math::VecActivations<T, platform::jit::isa_any> act_functor; math::VecActivations<T, platform::isa_any> act_functor;
act_gate = act_functor(act_gate_str); act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str); act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str); act_cand = act_functor(act_cand_str);
......
...@@ -217,13 +217,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> { ...@@ -217,13 +217,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \ auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \ auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \ auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \ if (platform::MayIUse(platform::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \ math::VecActivations<T, platform::avx> act_functor; \
act_gate = act_functor(act_gate_str); \ act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \ act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \ act_cand = act_functor(act_cand_str); \
} else { \ } else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \ math::VecActivations<T, platform::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \ act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \ act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \ act_cand = act_functor(act_cand_str); \
......
...@@ -151,11 +151,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> { ...@@ -151,11 +151,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
std::function<void(const int, const T*, T*)> fc_act; std::function<void(const int, const T*, T*)> fc_act;
auto& fc_act_str = ctx.Attr<std::string>("fc_activation"); auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
if (platform::jit::MayIUse(platform::jit::avx)) { if (platform::MayIUse(platform::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor; math::VecActivations<T, platform::avx> act_functor;
fc_act = act_functor(fc_act_str); fc_act = act_functor(fc_act_str);
} else { } else {
math::VecActivations<T, platform::jit::isa_any> act_functor; math::VecActivations<T, platform::isa_any> act_functor;
fc_act = act_functor(fc_act_str); fc_act = act_functor(fc_act_str);
} }
......
...@@ -77,7 +77,7 @@ inline void vec_scal<double>(const int n, const double a, double* x) { ...@@ -77,7 +77,7 @@ inline void vec_scal<double>(const int n, const double a, double* x) {
#endif #endif
// MKL scal only support inplace, choose this if src and dst are not equal // MKL scal only support inplace, choose this if src and dst are not equal
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_scal(const int n, const T a, const T* x, T* y) { inline void vec_scal(const int n, const T a, const T* x, T* y) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = a * x[i]; y[i] = a * x[i];
...@@ -85,12 +85,12 @@ inline void vec_scal(const int n, const T a, const T* x, T* y) { ...@@ -85,12 +85,12 @@ inline void vec_scal(const int n, const T a, const T* x, T* y) {
} }
template <> template <>
inline void vec_scal<float, platform::jit::avx>(const int n, const float a, inline void vec_scal<float, platform::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_scal<float, platform::jit::isa_any>(n, a, x, y); vec_scal<float, platform::isa_any>(n, a, x, y);
return; return;
} }
const int rest = n % block; const int rest = n % block;
...@@ -114,24 +114,24 @@ inline void vec_scal<float, platform::jit::avx>(const int n, const float a, ...@@ -114,24 +114,24 @@ inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
y[i] = a * x[i]; y[i] = a * x[i];
} }
#else #else
vec_scal<float, platform::jit::isa_any>(n, a, x, y); vec_scal<float, platform::isa_any>(n, a, x, y);
#endif #endif
} }
template <> template <>
inline void vec_scal<float, platform::jit::avx2>(const int n, const float a, inline void vec_scal<float, platform::avx2>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
vec_scal<float, platform::jit::avx>(n, a, x, y); vec_scal<float, platform::avx>(n, a, x, y);
} }
template <> template <>
inline void vec_scal<float, platform::jit::avx512f>(const int n, const float a, inline void vec_scal<float, platform::avx512f>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_scal<float, platform::jit::avx2>(n, a, x, y); vec_scal<float, platform::avx2>(n, a, x, y);
} }
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_bias_sub(const int n, const T a, const T* x, T* y) { inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = a - x[i]; y[i] = a - x[i];
...@@ -139,12 +139,12 @@ inline void vec_bias_sub(const int n, const T a, const T* x, T* y) { ...@@ -139,12 +139,12 @@ inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
} }
template <> template <>
inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a, inline void vec_bias_sub<float, platform::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y); vec_bias_sub<float, platform::isa_any>(n, a, x, y);
return; return;
} }
const int rest = n % block; const int rest = n % block;
...@@ -168,27 +168,25 @@ inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a, ...@@ -168,27 +168,25 @@ inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
y[i] = a - x[i]; y[i] = a - x[i];
} }
#else #else
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y); vec_bias_sub<float, platform::isa_any>(n, a, x, y);
#endif #endif
} }
template <> template <>
inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a, inline void vec_bias_sub<float, platform::avx2>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
vec_bias_sub<float, platform::jit::avx>(n, a, x, y); vec_bias_sub<float, platform::avx>(n, a, x, y);
} }
template <> template <>
inline void vec_bias_sub<float, platform::jit::avx512f>(const int n, inline void vec_bias_sub<float, platform::avx512f>(const int n, const float a,
const float a, const float* x, float* y) {
const float* x,
float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_bias_sub<float, platform::jit::avx2>(n, a, x, y); vec_bias_sub<float, platform::avx2>(n, a, x, y);
} }
// out = x*y + (1-x)*z // out = x*y + (1-x)*z
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
out[i] = x[i] * y[i] + (static_cast<T>(1) - x[i]) * z[i]; out[i] = x[i] * y[i] + (static_cast<T>(1) - x[i]) * z[i];
...@@ -196,13 +194,13 @@ inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { ...@@ -196,13 +194,13 @@ inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
} }
template <> template <>
inline void vec_cross<float, platform::jit::avx>(const int n, const float* x, inline void vec_cross<float, platform::avx>(const int n, const float* x,
const float* y, const float* z, const float* y, const float* z,
float* out) { float* out) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out); vec_cross<float, platform::isa_any>(n, x, y, z, out);
return; return;
} }
const int rest = n % block; const int rest = n % block;
...@@ -228,25 +226,26 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x, ...@@ -228,25 +226,26 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
out[i] = x[i] * y[i] + (1.f - x[i]) * z[i]; out[i] = x[i] * y[i] + (1.f - x[i]) * z[i];
} }
#else #else
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out); vec_cross<float, platform::isa_any>(n, x, y, z, out);
#endif #endif
} }
template <> template <>
inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x, inline void vec_cross<float, platform::avx2>(const int n, const float* x,
const float* y, const float* y, const float* z,
const float* z, float* out) { float* out) {
vec_cross<float, platform::jit::avx>(n, x, y, z, out); vec_cross<float, platform::avx>(n, x, y, z, out);
} }
template <> template <>
inline void vec_cross<float, platform::jit::avx512f>( inline void vec_cross<float, platform::avx512f>(const int n, const float* x,
const int n, const float* x, const float* y, const float* z, float* out) { const float* y, const float* z,
float* out) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_cross<float, platform::jit::avx>(n, x, y, z, out); vec_cross<float, platform::avx>(n, x, y, z, out);
} }
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_add_bias(const int n, const T a, const T* x, T* y) { inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = x[i] + a; y[i] = x[i] + a;
...@@ -254,12 +253,12 @@ inline void vec_add_bias(const int n, const T a, const T* x, T* y) { ...@@ -254,12 +253,12 @@ inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
} }
template <> template <>
inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a, inline void vec_add_bias<float, platform::avx>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_add_bias<float, platform::jit::isa_any>(n, a, x, y); vec_add_bias<float, platform::isa_any>(n, a, x, y);
return; return;
} }
const int rest = n % block; const int rest = n % block;
...@@ -283,32 +282,30 @@ inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a, ...@@ -283,32 +282,30 @@ inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
y[i] = x[i] + a; y[i] = x[i] + a;
} }
#else #else
vec_add_bias<float, platform::jit::isa_any>(n, a, x, y); vec_add_bias<float, platform::isa_any>(n, a, x, y);
#endif #endif
} }
template <> template <>
inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a, inline void vec_add_bias<float, platform::avx2>(const int n, const float a,
const float* x, float* y) { const float* x, float* y) {
vec_add_bias<float, platform::jit::avx>(n, a, x, y); vec_add_bias<float, platform::avx>(n, a, x, y);
} }
template <> template <>
inline void vec_add_bias<float, platform::jit::avx512f>(const int n, inline void vec_add_bias<float, platform::avx512f>(const int n, const float a,
const float a, const float* x, float* y) {
const float* x,
float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_add_bias<float, platform::jit::avx2>(n, a, x, y); vec_add_bias<float, platform::avx2>(n, a, x, y);
} }
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_identity(const int n, const T* x, T* y) { inline void vec_identity(const int n, const T* x, T* y) {
// do nothing // do nothing
return; return;
} }
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_sigmoid(const int n, const T* x, T* y) { inline void vec_sigmoid(const int n, const T* x, T* y) {
const T min = SIGMOID_THRESHOLD_MIN; const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX; const T max = SIGMOID_THRESHOLD_MAX;
...@@ -323,12 +320,12 @@ inline void vec_sigmoid(const int n, const T* x, T* y) { ...@@ -323,12 +320,12 @@ inline void vec_sigmoid(const int n, const T* x, T* y) {
} }
template <> template <>
inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x, inline void vec_sigmoid<float, platform::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block) { if (n < block) {
vec_sigmoid<float, platform::jit::isa_any>(n, x, y); vec_sigmoid<float, platform::isa_any>(n, x, y);
return; return;
} }
const int rest = n % block; const int rest = n % block;
...@@ -377,25 +374,24 @@ inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x, ...@@ -377,25 +374,24 @@ inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
y[i] = 1.f / (1.f + y[i]); y[i] = 1.f / (1.f + y[i]);
} }
#else #else
vec_sigmoid<float, platform::jit::isa_any>(n, x, y); vec_sigmoid<float, platform::isa_any>(n, x, y);
#endif #endif
} }
template <> template <>
inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x, inline void vec_sigmoid<float, platform::avx2>(const int n, const float* x,
float* y) { float* y) {
vec_sigmoid<float, platform::jit::avx>(n, x, y); vec_sigmoid<float, platform::avx>(n, x, y);
} }
template <> template <>
inline void vec_sigmoid<float, platform::jit::avx512f>(const int n, inline void vec_sigmoid<float, platform::avx512f>(const int n, const float* x,
const float* x,
float* y) { float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_sigmoid<float, platform::jit::avx2>(n, x, y); vec_sigmoid<float, platform::avx2>(n, x, y);
} }
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_tanh(const int n, const T* x, T* y) { inline void vec_tanh(const int n, const T* x, T* y) {
vec_scal<T, isa>(n, static_cast<T>(2), x, y); vec_scal<T, isa>(n, static_cast<T>(2), x, y);
vec_sigmoid<T, isa>(n, y, y); vec_sigmoid<T, isa>(n, y, y);
...@@ -404,7 +400,7 @@ inline void vec_tanh(const int n, const T* x, T* y) { ...@@ -404,7 +400,7 @@ inline void vec_tanh(const int n, const T* x, T* y) {
} }
// TODO(TJ): make relu clip // TODO(TJ): make relu clip
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
inline void vec_relu(const int n, const T* x, T* y) { inline void vec_relu(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0; y[i] = x[i] > 0 ? x[i] : 0;
...@@ -412,12 +408,12 @@ inline void vec_relu(const int n, const T* x, T* y) { ...@@ -412,12 +408,12 @@ inline void vec_relu(const int n, const T* x, T* y) {
} }
template <> template <>
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x, inline void vec_relu<float, platform::avx>(const int n, const float* x,
float* y) { float* y) {
#ifdef __AVX__ #ifdef __AVX__
constexpr int block = YMM_FLOAT_BLOCK; constexpr int block = YMM_FLOAT_BLOCK;
if (n < block * 4) { if (n < block * 4) {
vec_relu<float, platform::jit::isa_any>(n, x, y); vec_relu<float, platform::isa_any>(n, x, y);
return; return;
} }
...@@ -441,26 +437,26 @@ inline void vec_relu<float, platform::jit::avx>(const int n, const float* x, ...@@ -441,26 +437,26 @@ inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
#undef MOVE_ONE_STEP #undef MOVE_ONE_STEP
#else #else
vec_relu<float, platform::jit::isa_any>(n, x, y); vec_relu<float, platform::isa_any>(n, x, y);
#endif #endif
} }
template <> template <>
inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x, inline void vec_relu<float, platform::avx2>(const int n, const float* x,
float* y) { float* y) {
vec_relu<float, platform::jit::avx>(n, x, y); vec_relu<float, platform::avx>(n, x, y);
} }
template <> template <>
inline void vec_relu<float, platform::jit::avx512f>(const int n, const float* x, inline void vec_relu<float, platform::avx512f>(const int n, const float* x,
float* y) { float* y) {
// TODO(TJ): enable me // TODO(TJ): enable me
vec_relu<float, platform::jit::avx2>(n, x, y); vec_relu<float, platform::avx2>(n, x, y);
} }
// TODO(TJ): optimize double of sigmoid, tanh and relu if necessary // TODO(TJ): optimize double of sigmoid, tanh and relu if necessary
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any> template <typename T, platform::cpu_isa_t isa = platform::isa_any>
class VecActivations { class VecActivations {
public: public:
std::function<void(const int, const T*, T*)> operator()( std::function<void(const int, const T*, T*)> operator()(
......
...@@ -104,38 +104,42 @@ void TestAndBench(const int n, std::function<void(const int, const T*, T*)> tgt, ...@@ -104,38 +104,42 @@ void TestAndBench(const int n, std::function<void(const int, const T*, T*)> tgt,
} }
TEST(CpuVecTest, sigmoid) { TEST(CpuVecTest, sigmoid) {
namespace jit = paddle::platform::jit; namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>); TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>); TestAndBench<float>(sz, vec_sigmoid<float, platform::avx>,
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>); ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512f>, TestAndBench<float>(sz, vec_sigmoid<float, platform::avx2>,
ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, platform::avx512f>,
ref_sigmoid<float>); ref_sigmoid<float>);
} }
TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>); TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
} }
TEST(CpuVecTest, tanh) { TEST(CpuVecTest, tanh) {
namespace jit = paddle::platform::jit; namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>); TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>); TestAndBench<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>); TestAndBench<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>); TestAndBench<float>(sz, vec_tanh<float, platform::avx512f>,
ref_tanh<float>);
} }
TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>); TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
} }
TEST(CpuVecTest, relu) { TEST(CpuVecTest, relu) {
namespace jit = paddle::platform::jit; namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>); TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>); TestAndBench<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>); TestAndBench<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>); TestAndBench<float>(sz, vec_relu<float, platform::avx512f>,
ref_relu<float>);
} }
TestAndBench<double>(30, vec_relu<double>, ref_relu<double>); TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
} }
...@@ -162,38 +166,40 @@ void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt, ...@@ -162,38 +166,40 @@ void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
} }
TEST(CpuVecTest, inplace_sigmoid) { TEST(CpuVecTest, inplace_sigmoid) {
namespace jit = paddle::platform::jit; namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>); TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>); TestInplace<float>(sz, vec_sigmoid<float, platform::avx>,
TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>); ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx512f>, TestInplace<float>(sz, vec_sigmoid<float, platform::avx2>,
ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, platform::avx512f>,
ref_sigmoid<float>); ref_sigmoid<float>);
} }
TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>); TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
} }
TEST(CpuVecTest, inplace_tanh) { TEST(CpuVecTest, inplace_tanh) {
namespace jit = paddle::platform::jit; namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>); TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>); TestInplace<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>); TestInplace<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>); TestInplace<float>(sz, vec_tanh<float, platform::avx512f>, ref_tanh<float>);
} }
TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>); TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
} }
TEST(CpuVecTest, inplace_relu) { TEST(CpuVecTest, inplace_relu) {
namespace jit = paddle::platform::jit; namespace platform = paddle::platform;
using namespace paddle::operators::math; // NOLINT using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_relu<float>, ref_relu<float>); TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>); TestInplace<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>); TestInplace<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>); TestInplace<float>(sz, vec_relu<float, platform::avx512f>, ref_relu<float>);
} }
TestInplace<double>(30, vec_relu<double>, ref_relu<double>); TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
} }
...@@ -22,7 +22,7 @@ namespace math { ...@@ -22,7 +22,7 @@ namespace math {
namespace jitkernel { namespace jitkernel {
namespace gen { namespace gen {
using namespace platform::jit; // NOLINT using namespace platform; // NOLINT
bool VXXJitCode::init(int d, int scalar_index) { bool VXXJitCode::init(int d, int scalar_index) {
// It's not necessary to use avx512 since it would slow down the frequency // It's not necessary to use avx512 since it would slow down the frequency
......
...@@ -179,7 +179,7 @@ class VActJitCode : public JitCode { ...@@ -179,7 +179,7 @@ class VActJitCode : public JitCode {
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
using namespace platform::jit; // NOLINT using namespace platform; // NOLINT
// check all idx can not equal // check all idx can not equal
JMM jmm_src = JMM(src_idx); JMM jmm_src = JMM(src_idx);
JMM jmm_fx = JMM(fx_idx); JMM jmm_fx = JMM(fx_idx);
......
...@@ -36,7 +36,7 @@ void JitCode::preCode() { ...@@ -36,7 +36,7 @@ void JitCode::preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) { for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i])); push(Xbyak::Reg64(g_abi_regs[i]));
} }
if (platform::jit::MayIUse(platform::jit::avx512f)) { if (platform::MayIUse(platform::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
} }
} }
......
...@@ -21,8 +21,6 @@ namespace operators { ...@@ -21,8 +21,6 @@ namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit;
KernelPool& KernelPool::Instance() { KernelPool& KernelPool::Instance() {
static thread_local KernelPool g_jit_kernels; static thread_local KernelPool g_jit_kernels;
return g_jit_kernels; return g_jit_kernels;
......
...@@ -30,7 +30,6 @@ namespace paddle { ...@@ -30,7 +30,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit;
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
template <typename T> template <typename T>
...@@ -125,7 +124,7 @@ bool VMulKernelImpl<float>::useJIT(int d) { ...@@ -125,7 +124,7 @@ bool VMulKernelImpl<float>::useJIT(int d) {
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
template <> template <>
bool VMulKernelImpl<float>::useMKL(int d) { bool VMulKernelImpl<float>::useMKL(int d) {
return jit::MayIUse(jit::avx512f) && d > 512; return platform::MayIUse(platform::avx512f) && d > 512;
} }
template <> template <>
......
...@@ -25,10 +25,8 @@ namespace operators { ...@@ -25,10 +25,8 @@ namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit;
/* CRF Decode JitKernel */ /* CRF Decode JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::cpu_isa_t isa, jit_block>
class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
public: public:
explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel<T>() { explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel<T>() {
...@@ -101,7 +99,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -101,7 +99,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
#define INTRIAVX_FLOAT(block) \ #define INTRIAVX_FLOAT(block) \
template <> \ template <> \
CRFDecodeKernelImpl<float, jit::avx, block>::CRFDecodeKernelImpl( \ CRFDecodeKernelImpl<float, platform::avx, block>::CRFDecodeKernelImpl( \
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
...@@ -109,7 +107,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -109,7 +107,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \ void CRFDecodeKernelImpl<float, platform::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \ INIT_ALPHA(YMM_FLOAT_BLOCK) \
...@@ -204,7 +202,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -204,7 +202,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
#define INTRIAVX512_FLOAT(block) \ #define INTRIAVX512_FLOAT(block) \
template <> \ template <> \
CRFDecodeKernelImpl<float, jit::avx512f, block>::CRFDecodeKernelImpl( \ CRFDecodeKernelImpl<float, platform::avx512f, block>::CRFDecodeKernelImpl( \
int tag_num) \ int tag_num) \
: CRFDecodeKernel<float>() { \ : CRFDecodeKernel<float>() { \
this->num_ = tag_num; \ this->num_ = tag_num; \
...@@ -212,7 +210,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> { ...@@ -212,7 +210,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \ this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \ } \
template <> \ template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \ void CRFDecodeKernelImpl<float, platform::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \ const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \ int* track) const { \
INIT_ALPHA(ZMM_FLOAT_BLOCK) \ INIT_ALPHA(ZMM_FLOAT_BLOCK) \
...@@ -270,14 +268,14 @@ INTRIAVX_FLOAT(kEQ16); ...@@ -270,14 +268,14 @@ INTRIAVX_FLOAT(kEQ16);
INTRIAVX_FLOAT(kGT16); INTRIAVX_FLOAT(kGT16);
#endif #endif
#ifdef __AVX2__ #ifdef __AVX2__
INTRIAVX2_FLOAT(jit::avx2, kEQ8); INTRIAVX2_FLOAT(platform::avx2, kEQ8);
INTRIAVX2_FLOAT(jit::avx2, kGT8LT16); INTRIAVX2_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX2_FLOAT(jit::avx2, kEQ16); INTRIAVX2_FLOAT(platform::avx2, kEQ16);
INTRIAVX2_FLOAT(jit::avx2, kGT16); INTRIAVX2_FLOAT(platform::avx2, kGT16);
#endif #endif
#ifdef __AVX512F__ #ifdef __AVX512F__
INTRIAVX2_FLOAT(jit::avx512f, kEQ8); INTRIAVX2_FLOAT(platform::avx512f, kEQ8);
INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16); INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX512_FLOAT(kEQ16); INTRIAVX512_FLOAT(kEQ16);
INTRIAVX512_FLOAT(kGT16); INTRIAVX512_FLOAT(kGT16);
#endif #endif
......
...@@ -29,7 +29,6 @@ namespace paddle { ...@@ -29,7 +29,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit;
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup // try to use MKL to speedup
......
...@@ -22,10 +22,8 @@ namespace operators { ...@@ -22,10 +22,8 @@ namespace operators {
namespace math { namespace math {
namespace jitkernel { namespace jitkernel {
namespace jit = platform::jit;
/* Layer Norm JitKernel */ /* Layer Norm JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block> template <typename T, platform::cpu_isa_t isa, jit_block>
class LayerNormKernelImpl : public LayerNormKernel<T> { class LayerNormKernelImpl : public LayerNormKernel<T> {
public: public:
explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() { explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
...@@ -90,7 +88,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> { ...@@ -90,7 +88,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
this->end_ = this->num_ - this->rest_; \ this->end_ = this->num_ - this->rest_; \
} \ } \
template <> \ template <> \
void LayerNormKernelImpl<float, jit::avx, block>::Compute( \ void LayerNormKernelImpl<float, platform::avx, block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \ float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \ const float* bias, int height, const float epsilon) const { \
__m256 sum; \ __m256 sum; \
...@@ -219,16 +217,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> { ...@@ -219,16 +217,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
} }
#ifdef __AVX__ #ifdef __AVX__
INTRIAVX_FLOAT(jit::avx, kEQ8); INTRIAVX_FLOAT(platform::avx, kEQ8);
INTRIAVX_FLOAT(jit::avx, kGT8LT16); INTRIAVX_FLOAT(platform::avx, kGT8LT16);
INTRIAVX_FLOAT(jit::avx, kEQ16); INTRIAVX_FLOAT(platform::avx, kEQ16);
INTRIAVX_FLOAT(jit::avx, kGT16); INTRIAVX_FLOAT(platform::avx, kGT16);
#endif #endif
#ifdef __AVX2__ #ifdef __AVX2__
INTRIAVX_FLOAT(jit::avx2, kEQ8); INTRIAVX_FLOAT(platform::avx2, kEQ8);
INTRIAVX_FLOAT(jit::avx2, kGT8LT16); INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX_FLOAT(jit::avx2, kEQ16); INTRIAVX_FLOAT(platform::avx2, kEQ16);
INTRIAVX_FLOAT(jit::avx2, kGT16); INTRIAVX_FLOAT(platform::avx2, kGT16);
#endif #endif
#undef INTRIAVX_FLOAT #undef INTRIAVX_FLOAT
......
...@@ -92,7 +92,6 @@ namespace jitkernel { ...@@ -92,7 +92,6 @@ namespace jitkernel {
JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \ JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
JITKERNEL_IMPL) JITKERNEL_IMPL)
namespace jit = platform::jit;
// TODO(TJ): below defines are deprecated, would be remove recently // TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ #define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < YMM_FLOAT_BLOCK) { \ if (d < YMM_FLOAT_BLOCK) { \
...@@ -108,14 +107,14 @@ namespace jit = platform::jit; ...@@ -108,14 +107,14 @@ namespace jit = platform::jit;
} }
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ #define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (jit::MayIUse(jit::avx512f)) { \ if (platform::MayIUse(platform::avx512f)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \ SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \
} else if (jit::MayIUse(jit::avx2)) { \ } else if (platform::MayIUse(platform::avx2)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx2); \ SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \
} else if (jit::MayIUse(jit::avx)) { \ } else if (platform::MayIUse(platform::avx)) { \
SEARCH_BLOCK(macro_, ker, dtype, jit::avx); \ SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \
} else { \ } else { \
SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \ SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \
} }
#define JITKERNEL_KEY(ker_key, dtype_key) \ #define JITKERNEL_KEY(ker_key, dtype_key) \
...@@ -156,10 +155,10 @@ namespace jit = platform::jit; ...@@ -156,10 +155,10 @@ namespace jit = platform::jit;
marco_declare, macro_key, macro_impl) marco_declare, macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \ #define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \ macro_(platform::avx512f, block); \
macro_(jit::avx2, block); \ macro_(platform::avx2, block); \
macro_(jit::avx, block); \ macro_(platform::avx, block); \
macro_(jit::isa_any, block) macro_(platform::isa_any, block)
#define FOR_EACH_BLOCK(macro_, isa) \ #define FOR_EACH_BLOCK(macro_, isa) \
macro_(isa, kLT8); \ macro_(isa, kLT8); \
...@@ -169,10 +168,10 @@ namespace jit = platform::jit; ...@@ -169,10 +168,10 @@ namespace jit = platform::jit;
macro_(isa, kGT16) macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \ #define FOR_EACH_ISA_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, jit::avx512f); \ FOR_EACH_BLOCK(macro_, platform::avx512f); \
FOR_EACH_BLOCK(macro_, jit::avx2); \ FOR_EACH_BLOCK(macro_, platform::avx2); \
FOR_EACH_BLOCK(macro_, jit::avx); \ FOR_EACH_BLOCK(macro_, platform::avx); \
FOR_EACH_BLOCK(macro_, jit::isa_any) FOR_EACH_BLOCK(macro_, platform::isa_any)
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
......
...@@ -705,7 +705,7 @@ TEST(JitKernel, pool) { ...@@ -705,7 +705,7 @@ TEST(JitKernel, pool) {
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false); jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
// empty call it to avoid unknown flag 'use_pinned_memory' on Mac // empty call it to avoid unknown flag 'use_pinned_memory' on Mac
paddle::platform::jit::MayIUse(paddle::platform::jit::avx); paddle::platform::MayIUse(paddle::platform::avx);
const auto& plstm1 = const auto& plstm1 =
jit::KernelPool::Instance() jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr); .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
......
...@@ -123,7 +123,6 @@ size_t CUDAPinnedMaxChunkSize() { ...@@ -123,7 +123,6 @@ size_t CUDAPinnedMaxChunkSize() {
return CUDAPinnedMaxAllocSize() / 256; return CUDAPinnedMaxAllocSize() / 256;
} }
namespace jit {
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
static Xbyak::util::Cpu cpu; static Xbyak::util::Cpu cpu;
bool MayIUse(const cpu_isa_t cpu_isa) { bool MayIUse(const cpu_isa_t cpu_isa) {
...@@ -165,6 +164,5 @@ bool MayIUse(const cpu_isa_t cpu_isa) { ...@@ -165,6 +164,5 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
} }
#endif #endif
} // namespace jit
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -39,7 +39,6 @@ size_t CUDAPinnedMinChunkSize(); ...@@ -39,7 +39,6 @@ size_t CUDAPinnedMinChunkSize();
//! Get the maximum chunk size for buddy allocator. //! Get the maximum chunk size for buddy allocator.
size_t CUDAPinnedMaxChunkSize(); size_t CUDAPinnedMaxChunkSize();
namespace jit {
typedef enum { typedef enum {
isa_any, isa_any,
sse42, sse42,
...@@ -55,7 +54,5 @@ typedef enum { ...@@ -55,7 +54,5 @@ typedef enum {
// May I use some instruction // May I use some instruction
bool MayIUse(const cpu_isa_t cpu_isa); bool MayIUse(const cpu_isa_t cpu_isa);
} // namespace jit
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
#endif #endif
#if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__) #if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__)
if (platform::jit::MayIUse(platform::jit::avx)) { if (platform::MayIUse(platform::avx)) {
#ifndef __AVX__ #ifndef __AVX__
LOG(WARNING) << "AVX is available, Please re-compile on local machine"; LOG(WARNING) << "AVX is available, Please re-compile on local machine";
#endif #endif
...@@ -131,10 +131,10 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -131,10 +131,10 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
" version or compile from source code." " version or compile from source code."
#ifdef __AVX512F__ #ifdef __AVX512F__
if (!platform::jit::MayIUse(platform::jit::avx512f)) { if (!platform::MayIUse(platform::avx512f)) {
if (platform::jit::MayIUse(platform::jit::avx2)) { if (platform::MayIUse(platform::avx2)) {
AVX_GUIDE(AVX512, AVX2); AVX_GUIDE(AVX512, AVX2);
} else if (platform::jit::MayIUse(platform::jit::avx)) { } else if (platform::MayIUse(platform::avx)) {
AVX_GUIDE(AVX512, AVX); AVX_GUIDE(AVX512, AVX);
} else { } else {
AVX_GUIDE(AVX512, NonAVX); AVX_GUIDE(AVX512, NonAVX);
...@@ -143,8 +143,8 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -143,8 +143,8 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
#endif #endif
#ifdef __AVX2__ #ifdef __AVX2__
if (!platform::jit::MayIUse(platform::jit::avx2)) { if (!platform::MayIUse(platform::avx2)) {
if (platform::jit::MayIUse(platform::jit::avx)) { if (platform::MayIUse(platform::avx)) {
AVX_GUIDE(AVX2, AVX); AVX_GUIDE(AVX2, AVX);
} else { } else {
AVX_GUIDE(AVX2, NonAVX); AVX_GUIDE(AVX2, NonAVX);
...@@ -153,7 +153,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -153,7 +153,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
#endif #endif
#ifdef __AVX__ #ifdef __AVX__
if (!platform::jit::MayIUse(platform::jit::avx)) { if (!platform::MayIUse(platform::avx)) {
AVX_GUIDE(AVX, NonAVX); AVX_GUIDE(AVX, NonAVX);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册