diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 9b943440a869e213db4ed761cfe7c508bc5e94ae..75fc59125f21901b6781315eb3d7dba36b7f11f2 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -231,10 +231,10 @@ use lstm_x_t as input and compute as standard LSTM.
 template <typename T>
 inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
   if (bias) {
-    math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
-    math::vec_relu<T, platform::jit::avx>(n, y, y);
+    math::vec_add_bias<T, platform::avx>(n, *bias, x, y);
+    math::vec_relu<T, platform::avx>(n, y, y);
   } else {
-    math::vec_relu<T, platform::jit::avx>(n, x, y);
+    math::vec_relu<T, platform::avx>(n, x, y);
   }
 }
 
@@ -245,8 +245,8 @@ inline void vec_softmax(const int n, const T* x, T* y) {
   for (int i = 1; i < n; ++i) {
     scalar = scalar < x[i] ? x[i] : scalar;
   }
-  math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y);  // sub
-  math::vec_exp<T>(n, y, y);                                    // exp
+  math::vec_add_bias<T, platform::avx>(n, -scalar, x, y);  // sub
+  math::vec_exp<T>(n, y, y);                               // exp
   // sum
   scalar = T(0);
   for (int i = 0; i < n; ++i) {
@@ -302,13 +302,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
     auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
     auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
-    if (platform::jit::MayIUse(platform::jit::avx)) {
-      math::VecActivations<T, platform::jit::avx> act_functor;
+    if (platform::MayIUse(platform::avx)) {
+      math::VecActivations<T, platform::avx> act_functor;
       act_gate = act_functor(act_gate_str);
       act_cell = act_functor(act_cell_str);
       act_cand = act_functor(act_cand_str);
     } else {
-      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      math::VecActivations<T, platform::isa_any> act_functor;
       act_gate = act_functor(act_gate_str);
       act_cell = act_functor(act_cell_str);
       act_cand = act_functor(act_cand_str);
diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
index 6d463538d232e1a38f845e7abc3786568ca3bb21..1eb6523a2dfb358490a07bf1b806d5638442a4d5 100644
--- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
@@ -217,13 +217,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
   auto& act_gate_str = ctx.Attr<std::string>("gate_activation");          \
   auto& act_cell_str = ctx.Attr<std::string>("cell_activation");          \
   auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");     \
-  if (platform::jit::MayIUse(platform::jit::avx)) {                       \
-    math::VecActivations<T, platform::jit::avx> act_functor;              \
+  if (platform::MayIUse(platform::avx)) {                                 \
+    math::VecActivations<T, platform::avx> act_functor;                   \
     act_gate = act_functor(act_gate_str);                                 \
     act_cell = act_functor(act_cell_str);                                 \
     act_cand = act_functor(act_cand_str);                                 \
   } else {                                                                \
-    math::VecActivations<T, platform::jit::isa_any> act_functor;          \
+    math::VecActivations<T, platform::isa_any> act_functor;               \
     act_gate = act_functor(act_gate_str);                                 \
     act_cell = act_functor(act_cell_str);                                 \
     act_cand = act_functor(act_cand_str);                                 \
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
index 288b56fc2485138b20c5b53af3e950f1c1886ba5..17ed9771d074cf7ae8c6735e4cb859139503a0af 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
@@ -151,11 +151,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
     std::function<void(const int, const T*, T*)> fc_act;
     auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
-    if (platform::jit::MayIUse(platform::jit::avx)) {
-      math::VecActivations<T, platform::jit::avx> act_functor;
+    if (platform::MayIUse(platform::avx)) {
+      math::VecActivations<T, platform::avx> act_functor;
       fc_act = act_functor(fc_act_str);
     } else {
-      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      math::VecActivations<T, platform::isa_any> act_functor;
       fc_act = act_functor(fc_act_str);
     }
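Note: the three operator changes above are all the same pattern: probe the host once at kernel-setup time with `platform::MayIUse` and pick either the AVX-specialized functors or the portable fallback. A minimal self-contained sketch of that dispatch is below; `may_i_use`, `cpu_isa`, and the relu helpers are stand-in names for illustration, not Paddle's actual headers.

```cpp
#include <functional>

enum cpu_isa { isa_any, avx };

// A real implementation queries CPUID (Paddle does it via Xbyak); hard-coded
// here only so the sketch compiles on its own.
inline bool may_i_use(cpu_isa isa) { return isa == isa_any; }

// Scalar fallback path.
inline void relu_any(const int n, const float* x, float* y) {
  for (int i = 0; i < n; ++i) y[i] = x[i] > 0.f ? x[i] : 0.f;
}

// Stand-in for the AVX specialization (the real one uses 256-bit registers).
inline void relu_avx(const int n, const float* x, float* y) {
  relu_any(n, x, y);
}

// Chosen once when the kernel is constructed, like the VecActivations
// selection in the hunks above.
inline std::function<void(const int, const float*, float*)> pick_relu() {
  if (may_i_use(avx)) {
    return relu_avx;
  }
  return relu_any;
}
```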
diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
index 7d81aee596934308763002d440f52400f45b5f20..e1e4d168db3ca594b44396a6e30c5bfc03483eaf 100644
--- a/paddle/fluid/operators/math/cpu_vec.h
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -77,7 +77,7 @@ inline void vec_scal<double>(const int n, const double a, double* x) {
 #endif
 
 // MKL scal only support inplace, choose this if src and dst are not equal
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_scal(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a * x[i];
@@ -85,12 +85,12 @@ inline void vec_scal(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
-                                                const float* x, float* y) {
+inline void vec_scal<float, platform::avx>(const int n, const float a,
+                                           const float* x, float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_scal<float, platform::jit::isa_any>(n, a, x, y);
+    vec_scal<float, platform::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -114,24 +114,24 @@ inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
     y[i] = a * x[i];
   }
 #else
-  vec_scal<float, platform::jit::isa_any>(n, a, x, y);
+  vec_scal<float, platform::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_scal<float, platform::jit::avx2>(const int n, const float a,
-                                                 const float* x, float* y) {
-  vec_scal<float, platform::jit::avx>(n, a, x, y);
+inline void vec_scal<float, platform::avx2>(const int n, const float a,
+                                            const float* x, float* y) {
+  vec_scal<float, platform::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_scal<float, platform::jit::avx512f>(const int n, const float a,
-                                                    const float* x, float* y) {
+inline void vec_scal<float, platform::avx512f>(const int n, const float a,
+                                               const float* x, float* y) {
   // TODO(TJ): enable me
-  vec_scal<float, platform::jit::avx2>(n, a, x, y);
+  vec_scal<float, platform::avx2>(n, a, x, y);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a - x[i];
@@ -139,12 +139,12 @@ inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
-                                                    const float* x, float* y) {
+inline void vec_bias_sub<float, platform::avx>(const int n, const float a,
+                                               const float* x, float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
+    vec_bias_sub<float, platform::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -168,27 +168,25 @@ inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
     y[i] = a - x[i];
   }
 #else
-  vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
+  vec_bias_sub<float, platform::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a,
-                                                     const float* x, float* y) {
-  vec_bias_sub<float, platform::jit::avx>(n, a, x, y);
+inline void vec_bias_sub<float, platform::avx2>(const int n, const float a,
+                                                const float* x, float* y) {
+  vec_bias_sub<float, platform::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_bias_sub<float, platform::jit::avx512f>(const int n,
-                                                        const float a,
-                                                        const float* x,
-                                                        float* y) {
+inline void vec_bias_sub<float, platform::avx512f>(const int n, const float a,
-                                                   const float* x, float* y) {
+                                                   const float* x, float* y) {
   // TODO(TJ): enable me
-  vec_bias_sub<float, platform::jit::avx2>(n, a, x, y);
+  vec_bias_sub<float, platform::avx2>(n, a, x, y);
 }
 
 // out = x*y + (1-x)*z
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
   for (int i = 0; i < n; ++i) {
     out[i] = x[i] * y[i] + (static_cast<T>(1) - x[i]) * z[i];
@@ -196,13 +194,13 @@ inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
 }
 
 template <>
-inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
-                                                 const float* y, const float* z,
-                                                 float* out) {
+inline void vec_cross<float, platform::avx>(const int n, const float* x,
+                                            const float* y, const float* z,
+                                            float* out) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
+    vec_cross<float, platform::isa_any>(n, x, y, z, out);
     return;
   }
   const int rest = n % block;
@@ -228,25 +226,26 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
     out[i] = x[i] * y[i] + (1.f - x[i]) * z[i];
   }
 #else
-  vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
+  vec_cross<float, platform::isa_any>(n, x, y, z, out);
 #endif
 }
 
 template <>
-inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x,
-                                                  const float* y,
-                                                  const float* z, float* out) {
-  vec_cross<float, platform::jit::avx>(n, x, y, z, out);
+inline void vec_cross<float, platform::avx2>(const int n, const float* x,
+                                             const float* y, const float* z,
+                                             float* out) {
+  vec_cross<float, platform::avx>(n, x, y, z, out);
 }
 
 template <>
-inline void vec_cross<float, platform::jit::avx512f>(
-    const int n, const float* x, const float* y, const float* z, float* out) {
+inline void vec_cross<float, platform::avx512f>(const int n, const float* x,
+                                                const float* y, const float* z,
+                                                float* out) {
   // TODO(TJ): enable me
-  vec_cross<float, platform::jit::avx2>(n, x, y, z, out);
+  vec_cross<float, platform::avx2>(n, x, y, z, out);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = x[i] + a;
@@ -254,12 +253,12 @@ inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
-                                                    const float* x, float* y) {
+inline void vec_add_bias<float, platform::avx>(const int n, const float a,
+                                               const float* x, float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
+    vec_add_bias<float, platform::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -283,32 +282,30 @@ inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
     y[i] = x[i] + a;
   }
 #else
-  vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
+  vec_add_bias<float, platform::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a,
-                                                     const float* x, float* y) {
-  vec_add_bias<float, platform::jit::avx>(n, a, x, y);
+inline void vec_add_bias<float, platform::avx2>(const int n, const float a,
+                                                const float* x, float* y) {
+  vec_add_bias<float, platform::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_add_bias<float, platform::jit::avx512f>(const int n,
-                                                        const float a,
-                                                        const float* x,
-                                                        float* y) {
+inline void vec_add_bias<float, platform::avx512f>(const int n, const float a,
+                                                   const float* x, float* y) {
   // TODO(TJ): enable me
-  vec_add_bias<float, platform::jit::avx2>(n, a, x, y);
+  vec_add_bias<float, platform::avx2>(n, a, x, y);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_identity(const int n, const T* x, T* y) {
   // do nothing
   return;
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_sigmoid(const int n, const T* x, T* y) {
   const T min = SIGMOID_THRESHOLD_MIN;
   const T max = SIGMOID_THRESHOLD_MAX;
@@ -323,12 +320,12 @@ inline void vec_sigmoid(const int n, const T* x, T* y) {
 }
 
 template <>
-inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
-                                                   float* y) {
+inline void vec_sigmoid<float, platform::avx>(const int n, const float* x,
+                                              float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
+    vec_sigmoid<float, platform::isa_any>(n, x, y);
     return;
   }
   const int rest = n % block;
@@ -377,25 +374,24 @@ inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
     y[i] = 1.f / (1.f + y[i]);
   }
 #else
-  vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
+  vec_sigmoid<float, platform::isa_any>(n, x, y);
 #endif
 }
 
 template <>
-inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x,
-                                                    float* y) {
-  vec_sigmoid<float, platform::jit::avx>(n, x, y);
+inline void vec_sigmoid<float, platform::avx2>(const int n, const float* x,
+                                               float* y) {
+  vec_sigmoid<float, platform::avx>(n, x, y);
 }
 
 template <>
-inline void vec_sigmoid<float, platform::jit::avx512f>(const int n,
-                                                       const float* x,
-                                                       float* y) {
+inline void vec_sigmoid<float, platform::avx512f>(const int n, const float* x,
+                                                  float* y) {
   // TODO(TJ): enable me
-  vec_sigmoid<float, platform::jit::avx2>(n, x, y);
+  vec_sigmoid<float, platform::avx2>(n, x, y);
 }
 
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_tanh(const int n, const T* x, T* y) {
   vec_scal<T, isa>(n, static_cast<T>(2), x, y);
   vec_sigmoid<T, isa>(n, y, y);
@@ -404,7 +400,7 @@ inline void vec_tanh(const int n, const T* x, T* y) {
 }
 
 // TODO(TJ): make relu clip
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 inline void vec_relu(const int n, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = x[i] > 0 ? x[i] : 0;
@@ -412,12 +408,12 @@ inline void vec_relu(const int n, const T* x, T* y) {
 }
 
 template <>
-inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
-                                                float* y) {
+inline void vec_relu<float, platform::avx>(const int n, const float* x,
+                                           float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block * 4) {
-    vec_relu<float, platform::jit::isa_any>(n, x, y);
+    vec_relu<float, platform::isa_any>(n, x, y);
     return;
   }
 
@@ -441,26 +437,26 @@ inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
 #undef MOVE_ONE_STEP
 
 #else
-  vec_relu<float, platform::jit::isa_any>(n, x, y);
+  vec_relu<float, platform::isa_any>(n, x, y);
 #endif
 }
 
 template <>
-inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
-                                                 float* y) {
-  vec_relu<float, platform::jit::avx>(n, x, y);
+inline void vec_relu<float, platform::avx2>(const int n, const float* x,
+                                            float* y) {
+  vec_relu<float, platform::avx>(n, x, y);
 }
 
 template <>
-inline void vec_relu<float, platform::jit::avx512f>(const int n, const float* x,
-                                                    float* y) {
+inline void vec_relu<float, platform::avx512f>(const int n, const float* x,
+                                               float* y) {
   // TODO(TJ): enable me
-  vec_relu<float, platform::jit::avx2>(n, x, y);
+  vec_relu<float, platform::avx2>(n, x, y);
 }
 
 // TODO(TJ): optimize double of sigmoid, tanh and relu if necessary
-template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+template <typename T, platform::cpu_isa_t isa = platform::isa_any>
 class VecActivations {
  public:
  std::function<void(const int, const T*, T*)> operator()(
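Note: every float specialization in cpu_vec.h has the same blocked-AVX shape: process YMM_FLOAT_BLOCK (8 floats, one 256-bit register) per iteration, finish the remainder with a scalar loop, and compile down to the generic version when `__AVX__` is absent. A rough standalone sketch of that shape for relu follows; it is illustrative only, not the exact Paddle kernel (which also unrolls and reuses registers).

```cpp
#include <immintrin.h>

void relu_avx_sketch(const int n, const float* x, float* y) {
#ifdef __AVX__
  constexpr int block = 8;  // YMM_FLOAT_BLOCK: 256 bits / 32-bit float
  const int end = n - n % block;
  const __m256 zero = _mm256_setzero_ps();
  for (int i = 0; i < end; i += block) {
    __m256 v = _mm256_loadu_ps(x + i);            // load 8 floats
    _mm256_storeu_ps(y + i, _mm256_max_ps(v, zero));  // max(x, 0)
  }
  for (int i = end; i < n; ++i) {  // scalar tail, like the "rest" loops above
    y[i] = x[i] > 0.f ? x[i] : 0.f;
  }
#else
  for (int i = 0; i < n; ++i) {  // same fallback the #else branches take
    y[i] = x[i] > 0.f ? x[i] : 0.f;
  }
#endif
}
```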
diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc
index c37fa291a259550a3cb6d4f3dd9d5a415c3a2130..28eb9cadc9d4258bf4f8f71a06e029531e448014 100644
--- a/paddle/fluid/operators/math/cpu_vec_test.cc
+++ b/paddle/fluid/operators/math/cpu_vec_test.cc
@@ -104,38 +104,42 @@ void TestAndBench(const int n, std::function<void(const int, const T*, T*)> tgt,
 }
 
 TEST(CpuVecTest, sigmoid) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
-    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
-    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
-    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512f>,
+    TestAndBench<float>(sz, vec_sigmoid<float, platform::avx>,
+                        ref_sigmoid<float>);
+    TestAndBench<float>(sz, vec_sigmoid<float, platform::avx2>,
+                        ref_sigmoid<float>);
+    TestAndBench<float>(sz, vec_sigmoid<float, platform::avx512f>,
                         ref_sigmoid<float>);
   }
   TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
 }
 
 TEST(CpuVecTest, tanh) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
-    TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
-    TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
-    TestAndBench<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, platform::avx512f>,
+                        ref_tanh<float>);
   }
   TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
 }
 
 TEST(CpuVecTest, relu) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
-    TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
-    TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
-    TestAndBench<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, platform::avx512f>,
+                        ref_relu<float>);
   }
   TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
 }
@@ -162,38 +166,40 @@ void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
 }
 
 TEST(CpuVecTest, inplace_sigmoid) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
-    TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
-    TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
-    TestInplace<float>(sz, vec_sigmoid<float, jit::avx512f>,
+    TestInplace<float>(sz, vec_sigmoid<float, platform::avx>,
+                       ref_sigmoid<float>);
+    TestInplace<float>(sz, vec_sigmoid<float, platform::avx2>,
+                       ref_sigmoid<float>);
+    TestInplace<float>(sz, vec_sigmoid<float, platform::avx512f>,
                        ref_sigmoid<float>);
   }
   TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
 }
 
 TEST(CpuVecTest, inplace_tanh) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, platform::avx512f>, ref_tanh<float>);
   }
   TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
 }
 
 TEST(CpuVecTest, inplace_relu) {
-  namespace jit = paddle::platform::jit;
+  namespace platform = paddle::platform;
   using namespace paddle::operators::math;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, platform::avx512f>, ref_relu<float>);
  }
   TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
 }
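Note: the tests above run every ISA variant of a kernel against the scalar reference on the same input. Roughly, the TestAndBench/TestInplace helpers do the following; the sketch below uses a hypothetical `expect_close` helper and is simplified (the real harness also times the calls and uses gtest's EXPECT_NEAR).

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <vector>

using Kernel = std::function<void(const int, const float*, float*)>;

// Compare a candidate kernel against the scalar reference on random data.
void expect_close(const Kernel& tgt, const Kernel& ref, int n) {
  std::vector<float> x(n), y_tgt(n), y_ref(n);
  for (int i = 0; i < n; ++i) {
    x[i] = static_cast<float>(std::rand()) / RAND_MAX - 0.5f;
  }
  tgt(n, x.data(), y_tgt.data());
  ref(n, x.data(), y_ref.data());
  for (int i = 0; i < n; ++i) {
    assert(std::fabs(y_tgt[i] - y_ref[i]) < 1e-3f);
  }
}
```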
diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
index 52cbdf685dee651cbc1490dc6faacb8680004c89..78d0c3e8808f0daf6a18d2217664e965773b95ff 100644
--- a/paddle/fluid/operators/math/jit_code.cc
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -22,7 +22,7 @@ namespace math {
 namespace jitkernel {
 namespace gen {
 
-using namespace platform::jit;  // NOLINT
+using namespace platform;  // NOLINT
 
 bool VXXJitCode::init(int d, int scalar_index) {
   // It's not necessary to use avx512 since it would slow down the frequency
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index a9214621295a7740b804b26c02d216dd5118d8bb..e2b4761435594fdc952ff5dba5b5fa4f4aa98e6c 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -179,7 +179,7 @@ class VActJitCode : public JitCode {
   template <typename JMM>
   void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12,  // NOLINT
                int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
-    using namespace platform::jit;  // NOLINT
+    using namespace platform;  // NOLINT
     // check all idx can not equal
     JMM jmm_src = JMM(src_idx);
     JMM jmm_fx = JMM(fx_idx);
diff --git a/paddle/fluid/operators/math/jit_gen.cc b/paddle/fluid/operators/math/jit_gen.cc
index 6af39518ed926554c8c839bba701d3827923dba0..5c6672928e8c03ccb1920bd828f785084e422fc2 100644
--- a/paddle/fluid/operators/math/jit_gen.cc
+++ b/paddle/fluid/operators/math/jit_gen.cc
@@ -36,7 +36,7 @@ void JitCode::preCode() {
   for (int i = 0; i < num_g_abi_regs; ++i) {
     push(Xbyak::Reg64(g_abi_regs[i]));
   }
-  if (platform::jit::MayIUse(platform::jit::avx512f)) {
+  if (platform::MayIUse(platform::avx512f)) {
     mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
   }
 }
diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc
index 68b708b345334bc63b5e2e88c308d20ca6378e6b..118696ba47986e2dbf97535333c9817b7c264a54 100644
--- a/paddle/fluid/operators/math/jit_kernel.cc
+++ b/paddle/fluid/operators/math/jit_kernel.cc
@@ -21,8 +21,6 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
-
 KernelPool& KernelPool::Instance() {
   static thread_local KernelPool g_jit_kernels;
   return g_jit_kernels;
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index a0f93fd8e7eb7d81211724a6991a681e2a0ed9ce..8cf588efba52314650bfd376b95b10e6d4336b2e 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -30,7 +30,6 @@ namespace paddle {
 namespace operators {
 namespace math {
 namespace jitkernel {
-namespace jit = platform::jit;
 
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
@@ -125,7 +124,7 @@ bool VMulKernelImpl<float>::useJIT(int d) {
 #ifdef PADDLE_WITH_MKLML
 template <>
 bool VMulKernelImpl<float>::useMKL(int d) {
-  return jit::MayIUse(jit::avx512f) && d > 512;
+  return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc
index 4d26b81948238f18b097f535534fcfe9049b93c3..eeb305a88bee8f0e21b205684d24b19ca4631f65 100644
--- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc
@@ -25,10 +25,8 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
-
 /* CRF Decode JitKernel */
-template <typename T, jit::cpu_isa_t isa, jit_block>
+template <typename T, platform::cpu_isa_t isa, jit_block>
 class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
  public:
  explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel<T>() {
@@ -101,7 +99,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
 
 #define INTRIAVX_FLOAT(block)                                                \
   template <>                                                                \
-  CRFDecodeKernelImpl<float, jit::avx, block>::CRFDecodeKernelImpl(          \
+  CRFDecodeKernelImpl<float, platform::avx, block>::CRFDecodeKernelImpl(     \
       int tag_num)                                                           \
       : CRFDecodeKernel<float>() {                                           \
     this->num_ = tag_num;                                                    \
@@ -109,7 +107,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
     this->rest_ = this->num_ % YMM_FLOAT_BLOCK;                              \
   }                                                                          \
   template <>                                                                \
-  void CRFDecodeKernelImpl<float, jit::avx, block>::Compute(                 \
+  void CRFDecodeKernelImpl<float, platform::avx, block>::Compute(            \
       const int seq_len, const float* x, const float* w, float* alpha,       \
       int* track) const {                                                    \
     INIT_ALPHA(YMM_FLOAT_BLOCK)                                              \
@@ -204,7 +202,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
 
 #define INTRIAVX512_FLOAT(block)                                             \
   template <>                                                                \
-  CRFDecodeKernelImpl<float, jit::avx512f, block>::CRFDecodeKernelImpl(      \
+  CRFDecodeKernelImpl<float, platform::avx512f, block>::CRFDecodeKernelImpl( \
       int tag_num)                                                           \
       : CRFDecodeKernel<float>() {                                           \
     this->num_ = tag_num;                                                    \
@@ -212,7 +210,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
     this->rest_ = this->num_ % ZMM_FLOAT_BLOCK;                              \
   }                                                                          \
   template <>                                                                \
-  void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute(             \
+  void CRFDecodeKernelImpl<float, platform::avx512f, block>::Compute(        \
      const int seq_len, const float* x, const float* w, float* alpha,        \
      int* track) const {                                                     \
    INIT_ALPHA(ZMM_FLOAT_BLOCK)                                               \
@@ -270,14 +268,14 @@ INTRIAVX_FLOAT(kEQ16);
 INTRIAVX_FLOAT(kGT16);
 #endif
 #ifdef __AVX2__
-INTRIAVX2_FLOAT(jit::avx2, kEQ8);
-INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
-INTRIAVX2_FLOAT(jit::avx2, kEQ16);
-INTRIAVX2_FLOAT(jit::avx2, kGT16);
+INTRIAVX2_FLOAT(platform::avx2, kEQ8);
+INTRIAVX2_FLOAT(platform::avx2, kGT8LT16);
+INTRIAVX2_FLOAT(platform::avx2, kEQ16);
+INTRIAVX2_FLOAT(platform::avx2, kGT16);
 #endif
 #ifdef __AVX512F__
-INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
-INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
+INTRIAVX2_FLOAT(platform::avx512f, kEQ8);
+INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16);
 INTRIAVX512_FLOAT(kEQ16);
 INTRIAVX512_FLOAT(kGT16);
 #endif
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index 686f3dd9836cb9192088771753065c6add639620..7945cfb253a61b7d1191c39537254126e2bb85dd 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -29,7 +29,6 @@ namespace paddle {
 namespace operators {
 namespace math {
 namespace jitkernel {
-namespace jit = platform::jit;
 
 #ifdef PADDLE_WITH_MKLML
 // try to use MKL to speedup
diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
index 49904e6e8c7cd346bcbfb67c3a7574118b36e058..fead13ebadcd131afafc308740cdd39b1c53bc08 100644
--- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
+++ b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc
@@ -22,10 +22,8 @@ namespace operators {
 namespace math {
 namespace jitkernel {
 
-namespace jit = platform::jit;
-
 /* Layer Norm JitKernel */
-template <typename T, jit::cpu_isa_t isa, jit_block>
+template <typename T, platform::cpu_isa_t isa, jit_block>
 class LayerNormKernelImpl : public LayerNormKernel<T> {
  public:
  explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
@@ -90,7 +88,7 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
     this->end_ = this->num_ - this->rest_;                                   \
   }                                                                          \
   template <>                                                                \
-  void LayerNormKernelImpl<float, jit::avx, block>::Compute(                 \
+  void LayerNormKernelImpl<float, platform::avx, block>::Compute(            \
       float* x, float* out, float* mean, float* var, const float* scale,     \
       const float* bias, int height, const float epsilon) const {            \
     __m256 sum;                                                              \
@@ -219,16 +217,16 @@ class LayerNormKernelImpl : public LayerNormKernel<T> {
 }
 
 #ifdef __AVX__
-INTRIAVX_FLOAT(jit::avx, kEQ8);
-INTRIAVX_FLOAT(jit::avx, kGT8LT16);
-INTRIAVX_FLOAT(jit::avx, kEQ16);
-INTRIAVX_FLOAT(jit::avx, kGT16);
+INTRIAVX_FLOAT(platform::avx, kEQ8);
+INTRIAVX_FLOAT(platform::avx, kGT8LT16);
+INTRIAVX_FLOAT(platform::avx, kEQ16);
+INTRIAVX_FLOAT(platform::avx, kGT16);
 #endif
 #ifdef __AVX2__
-INTRIAVX_FLOAT(jit::avx2, kEQ8);
-INTRIAVX_FLOAT(jit::avx2, kGT8LT16);
-INTRIAVX_FLOAT(jit::avx2, kEQ16);
-INTRIAVX_FLOAT(jit::avx2, kGT16);
+INTRIAVX_FLOAT(platform::avx2, kEQ8);
+INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
+INTRIAVX_FLOAT(platform::avx2, kEQ16);
+INTRIAVX_FLOAT(platform::avx2, kGT16);
 #endif
 
 #undef INTRIAVX_FLOAT
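Note: the CRF-decode and layer-norm kernels are selected by a (element type, ISA, block size) template triple, and the INTRIAVX*_FLOAT macros stamp out one explicit specialization per (isa, block) pair. A condensed sketch of that scheme with illustrative names (not Paddle's actual class hierarchy):

```cpp
// Generic template supplies the scalar path; explicit specializations
// (here only <float, avx, kEQ8> as an example) override it with intrinsics.
enum cpu_isa { isa_any, avx };
enum jit_block { kLT8, kEQ8, kGT8 };

template <typename T, cpu_isa isa, jit_block block>
struct KernelImpl {
  static void Compute(const int n, const T* x, T* y) {
    for (int i = 0; i < n; ++i) y[i] = x[i];  // portable fallback
  }
};

template <>
struct KernelImpl<float, avx, kEQ8> {
  static void Compute(const int n, const float* x, float* y) {
    // An AVX body would live here; the macros in the diff generate these
    // specializations for every supported (isa, block) combination.
    for (int i = 0; i < n; ++i) y[i] = x[i];
  }
};
```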
diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h
index 5a3efd979f803d396a5084c199b1d71b88a77126..4dba3b56810794cb4839d26386ae77a8f4507977 100644
--- a/paddle/fluid/operators/math/jit_kernel_macro.h
+++ b/paddle/fluid/operators/math/jit_kernel_macro.h
@@ -92,7 +92,6 @@ namespace jitkernel {
                             JITKERNEL_DECLARE, JITKERNEL_FIND_KEY,     \
                             JITKERNEL_IMPL)
 
-namespace jit = platform::jit;
 // TODO(TJ): below defines are deprecated, would be remove recently
 #define SEARCH_BLOCK(macro_, ker, dtype, isa)                          \
   if (d < YMM_FLOAT_BLOCK) {                                           \
@@ -107,15 +106,15 @@ namespace jit = platform::jit;
     macro_(ker, dtype, isa, kGT16);                                    \
   }
 
-#define SEARCH_ISA_BLOCK(macro_, ker, dtype)        \
-  if (jit::MayIUse(jit::avx512f)) {                 \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \
-  } else if (jit::MayIUse(jit::avx2)) {             \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::avx2);    \
-  } else if (jit::MayIUse(jit::avx)) {              \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::avx);     \
-  } else {                                          \
-    SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
+#define SEARCH_ISA_BLOCK(macro_, ker, dtype)                  \
+  if (platform::MayIUse(platform::avx512f)) {                 \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f);      \
+  } else if (platform::MayIUse(platform::avx2)) {             \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::avx2);         \
+  } else if (platform::MayIUse(platform::avx)) {              \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::avx);          \
+  } else {                                                    \
+    SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any);      \
   }
 
 #define JITKERNEL_KEY(ker_key, dtype_key) \
@@ -156,10 +155,10 @@ namespace jit = platform::jit;
                          marco_declare, macro_key, macro_impl)
 
 #define FOR_EACH_ISA(macro_, block) \
-  macro_(jit::avx512f, block);      \
-  macro_(jit::avx2, block);         \
-  macro_(jit::avx, block);          \
-  macro_(jit::isa_any, block)
+  macro_(platform::avx512f, block); \
+  macro_(platform::avx2, block);    \
+  macro_(platform::avx, block);     \
+  macro_(platform::isa_any, block)
 
 #define FOR_EACH_BLOCK(macro_, isa) \
   macro_(isa, kLT8);                \
@@ -168,11 +167,11 @@ namespace jit = platform::jit;
   macro_(isa, kEQ16);               \
   macro_(isa, kGT16)
 
-#define FOR_EACH_ISA_BLOCK(macro_)      \
-  FOR_EACH_BLOCK(macro_, jit::avx512f); \
-  FOR_EACH_BLOCK(macro_, jit::avx2);    \
-  FOR_EACH_BLOCK(macro_, jit::avx);     \
-  FOR_EACH_BLOCK(macro_, jit::isa_any)
+#define FOR_EACH_ISA_BLOCK(macro_)           \
+  FOR_EACH_BLOCK(macro_, platform::avx512f); \
+  FOR_EACH_BLOCK(macro_, platform::avx2);    \
+  FOR_EACH_BLOCK(macro_, platform::avx);     \
+  FOR_EACH_BLOCK(macro_, platform::isa_any)
 
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index ed86a47e159cacd4f5572e22c7633f725aaeb516..19f7bd8909499c12fd5bee4db0d0a71a632e7f19 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -705,7 +705,7 @@ TEST(JitKernel, pool) {
   jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
 
   // empty call it to avoid unknown flag 'use_pinned_memory' on Mac
-  paddle::platform::jit::MayIUse(paddle::platform::jit::avx);
+  paddle::platform::MayIUse(paddle::platform::avx);
   const auto& plstm1 =
       jit::KernelPool::Instance()
           .template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc
index d466f28d1ea0a8327f8d7a45c3e55c5aacd61544..f9a32bfa4c15261ba6b79fc4efd3a1961f7c6d4d 100644
--- a/paddle/fluid/platform/cpu_info.cc
+++ b/paddle/fluid/platform/cpu_info.cc
@@ -123,7 +123,6 @@ size_t CUDAPinnedMaxChunkSize() {
   return CUDAPinnedMaxAllocSize() / 256;
 }
 
-namespace jit {
 #ifdef PADDLE_WITH_XBYAK
 static Xbyak::util::Cpu cpu;
 bool MayIUse(const cpu_isa_t cpu_isa) {
@@ -165,6 +164,5 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
 }
 #endif
 
-}  // namespace jit
 }  // namespace platform
 }  // namespace paddle
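Note: SEARCH_ISA_BLOCK above encodes the usual fallback chain: prefer avx512f, then avx2, then avx, else the portable path, with each probe answered by `platform::MayIUse`. A function-style sketch of the same chain, with the runtime query stubbed out so it stands alone:

```cpp
enum cpu_isa_t { isa_any, avx, avx2, avx512f };

// Stand-in for platform::MayIUse; pretend the host is an AVX2 machine.
inline bool may_i_use_stub(cpu_isa_t isa) { return isa <= avx2; }

// Highest ISA the host supports, probed from best to worst.
inline cpu_isa_t best_isa() {
  if (may_i_use_stub(avx512f)) return avx512f;
  if (may_i_use_stub(avx2)) return avx2;
  if (may_i_use_stub(avx)) return avx;
  return isa_any;
}
```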
diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h
index fd31ef77b46d5b5b641983a0421da31914c87c18..55dba545ff133b1c219ee58f6d1bb2d2130d1a59 100644
--- a/paddle/fluid/platform/cpu_info.h
+++ b/paddle/fluid/platform/cpu_info.h
@@ -39,7 +39,6 @@ size_t CUDAPinnedMinChunkSize();
 //! Get the maximum chunk size for buddy allocator.
 size_t CUDAPinnedMaxChunkSize();
 
-namespace jit {
 typedef enum {
   isa_any,
   sse42,
@@ -55,7 +54,5 @@ typedef enum {
 // May I use some instruction
 bool MayIUse(const cpu_isa_t cpu_isa);
 
-}  // namespace jit
-
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc
index 51b46450e419c6641b415a58b7551f7dd56627b2..0d10d82d74a2011b1b2bc088fe88cbfdb49600b8 100644
--- a/paddle/fluid/platform/init.cc
+++ b/paddle/fluid/platform/init.cc
@@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #endif
 
 #if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__)
-  if (platform::jit::MayIUse(platform::jit::avx)) {
+  if (platform::MayIUse(platform::avx)) {
 #ifndef __AVX__
     LOG(WARNING) << "AVX is available, Please re-compile on local machine";
 #endif
@@ -131,10 +131,10 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
   " version or compile from source code."
 
 #ifdef __AVX512F__
-  if (!platform::jit::MayIUse(platform::jit::avx512f)) {
-    if (platform::jit::MayIUse(platform::jit::avx2)) {
+  if (!platform::MayIUse(platform::avx512f)) {
+    if (platform::MayIUse(platform::avx2)) {
       AVX_GUIDE(AVX512, AVX2);
-    } else if (platform::jit::MayIUse(platform::jit::avx)) {
+    } else if (platform::MayIUse(platform::avx)) {
       AVX_GUIDE(AVX512, AVX);
     } else {
       AVX_GUIDE(AVX512, NonAVX);
@@ -143,8 +143,8 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #endif
 
 #ifdef __AVX2__
-  if (!platform::jit::MayIUse(platform::jit::avx2)) {
-    if (platform::jit::MayIUse(platform::jit::avx)) {
+  if (!platform::MayIUse(platform::avx2)) {
+    if (platform::MayIUse(platform::avx)) {
       AVX_GUIDE(AVX2, AVX);
     } else {
       AVX_GUIDE(AVX2, NonAVX);
@@ -153,7 +153,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #endif
 
 #ifdef __AVX__
-  if (!platform::jit::MayIUse(platform::jit::avx)) {
+  if (!platform::MayIUse(platform::avx)) {
     AVX_GUIDE(AVX, NonAVX);
   }
 #endif
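Note: the init.cc hunks pair the compile-time ISA baked into the binary (`#ifdef __AVX2__` and friends) with the runtime capability reported by `platform::MayIUse`, and warn when the two disagree. A self-contained sketch of that consistency check, with the runtime query stubbed:

```cpp
#include <cstdio>

// Stand-in for platform::MayIUse(platform::avx2); a real check queries CPUID.
inline bool cpu_supports_avx2() { return false; }

// Warn if the binary contains AVX2 code paths the host cannot execute,
// mirroring the AVX_GUIDE diagnostics above.
inline void check_isa_consistency() {
#ifdef __AVX2__
  if (!cpu_supports_avx2()) {
    std::fprintf(stderr,
                 "Binary compiled with AVX2 but the CPU lacks it; "
                 "expect illegal-instruction crashes.\n");
  }
#endif
}
```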