diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 1cb65346ee2b755b48f8dd8f1456a32861c3a0b6..8bab37c5830dfdcd5d6ccf1cc049387b496b0d04 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -232,40 +232,28 @@ use lstm_x_t as input and compute as standard LSTM.
 template <typename T>
 inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
   if (bias) {
-    for (int i = 0; i < n; ++i) {
-      y[i] = x[i] + bias[0];
-    }
-    math::vec_relu<T>(n, y, y);
+    math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
+    math::vec_relu<T, platform::jit::avx>(n, y, y);
   } else {
-    math::vec_relu<T>(n, x, y);
+    math::vec_relu<T, platform::jit::avx>(n, x, y);
   }
 }
 
-template <typename DeviceContext, typename T>
-inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
-                        const T* x, T* y) {
+template <typename T>
+inline void vec_softmax(const int n, const T* x, T* y) {
   T scalar = x[0];
   // max
   for (int i = 1; i < n; ++i) {
     scalar = scalar < x[i] ? x[i] : scalar;
   }
-
-  // sub
-  for (int i = 0; i < n; ++i) {
-    y[i] = x[i] - scalar;
-  }
-
-  // exp
-  blas.VEXP(n, y, y);
-
+  math::vec_add_bias<T>(n, -scalar, x, y);  // sub
+  math::vec_exp<T>(n, y, y);                // exp
   // sum
   scalar = T(0);
   for (int i = 0; i < n; ++i) {
     scalar += y[i];
   }
-
-  // scale
-  blas.SCAL(n, static_cast<T>(1) / scalar, y);
+  math::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
 }
 
 template <typename T>
@@ -311,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
     fc_out->Resize({max_seq_len, 1});
 
-    math::VecActivations<T> act_functor;
     std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
-    act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
-    act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
-    act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
+    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
+    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
+    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
+    if (platform::jit::MayIUse(platform::jit::avx)) {
+      math::VecActivations<T, platform::jit::avx> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    } else {
+      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    }
 
     const T* x_data = x->data<T>();
     const T* h0_data = h0 ? h0->data<T>() : NULL;
@@ -363,7 +361,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
            fc_out_data);
       }
       // 1d. softmax
-      vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
+      vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
       // mul x(seq_len*M) and sum pool
       math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                         cur_x_data, lstm_x_data);
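
Note: the rewritten vec_softmax above drops the BlasT<DeviceContext, T> handle and expresses softmax entirely through the new cpu_vec primitives: shift by the maximum (vec_add_bias with -max), exponentiate (vec_exp), then normalize (vec_scal with 1/sum). For reference, a minimal standalone sketch of the same three-step computation, with plain loops standing in for the Paddle helpers (softmax_sketch is an illustrative name, not part of this patch):

#include <algorithm>
#include <cmath>

template <typename T>
void softmax_sketch(const int n, const T* x, T* y) {
  T mx = *std::max_element(x, x + n);                 // max, for stability
  for (int i = 0; i < n; ++i) y[i] = x[i] - mx;       // sub  (vec_add_bias)
  for (int i = 0; i < n; ++i) y[i] = std::exp(y[i]);  // exp  (vec_exp)
  T sum = T(0);
  for (int i = 0; i < n; ++i) sum += y[i];            // sum
  const T scale = T(1) / sum;
  for (int i = 0; i < n; ++i) y[i] *= scale;          // scale (vec_scal)
}
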
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index d2b772d11379c218be77277b89f3ded7b59ab9f3..1b75df5d7d97e54dfdc461660e53a368311e3778 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -65,3 +65,4 @@ if(WITH_GPU)
     nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat)
+cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h
index 48c0da0e368a0fe6efcd758536e5659eeee26f7e..0bae926e9892986b59c6dfc7fa9b8778da1dfcb7 100644
--- a/paddle/fluid/operators/math/cpu_vec.h
+++ b/paddle/fluid/operators/math/cpu_vec.h
@@ -13,8 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <cmath>
 #include <string>
 #include "paddle/fluid/platform/cpu_info.h"
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
+
+#ifdef PADDLE_WITH_MKLML
+#include "paddle/fluid/platform/dynload/mklml.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -22,16 +30,161 @@ namespace math {
 
 #define SIGMOID_THRESHOLD_MIN -40.0
 #define SIGMOID_THRESHOLD_MAX 13.0
-#define EXP_MAX_INPUT 40.0
+
+#define AVX_FLOAT_BLOCK 8
+#define AVX_DOUBLE_BLOCK 4
+#define AVX2_FLOAT_BLOCK 8
+#define AVX2_DOUBLE_BLOCK 4
+#define AVX512_FLOAT_BLOCK 16
+#define AVX512_DOUBLE_BLOCK 8
 
 template <typename T>
-inline T sigmoid(T x) {
-  return 1. / (1. + exp(-x));
+inline void vec_exp(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::exp(x[i]);
+  }
 }
 
 template <typename T>
-inline T tanh(T x) {
-  return 2. * sigmoid<T>(2. * x) - 1.;
+inline void vec_scal(const int n, const T a, T* x) {
+  for (int i = 0; i < n; ++i) {
+    x[i] = a * x[i];
+  }
+}
+
+#ifdef PADDLE_WITH_MKLML
+template <>
+inline void vec_exp<float>(const int n, const float* x, float* y) {
+  platform::dynload::vsExp(n, x, y);
+}
+
+template <>
+inline void vec_exp<double>(const int n, const double* x, double* y) {
+  platform::dynload::vdExp(n, x, y);
+}
+
+template <>
+inline void vec_scal<float>(const int n, const float a, float* x) {
+  platform::dynload::cblas_sscal(n, a, x, 1);
+}
+
+template <>
+inline void vec_scal<double>(const int n, const double a, double* x) {
+  platform::dynload::cblas_dscal(n, a, x, 1);
+}
+#endif
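
Note: the PADDLE_WITH_MKLML block above is the standard specialization-dispatch idiom: the primary template stays as the portable scalar fallback, and full specializations reroute float/double to MKL's vsExp/vdExp and cblas_sscal/cblas_dscal. A distilled sketch of the idiom under an assumed build flag (USE_VENDOR_EXP and fast_exp_f32 are hypothetical stand-ins, not Paddle or MKL names):

#include <cmath>

template <typename T>
void my_exp(const int n, const T* x, T* y) {  // portable fallback
  for (int i = 0; i < n; ++i) {
    y[i] = std::exp(x[i]);
  }
}

#ifdef USE_VENDOR_EXP
void fast_exp_f32(const int n, const float* x, float* y);  // vendor kernel

template <>
void my_exp<float>(const int n, const float* x, float* y) {
  fast_exp_f32(n, x, y);  // takes over only the float instantiation
}
#endif
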
+
+// MKL scal only support inplace, choose this if src and dst are not equal
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_scal(const int n, const T a, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = a * x[i];
+  }
+}
+
+template <>
+inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
+                                                const float* x, float* y) {
+#ifdef __AVX__
+  constexpr int block = AVX_FLOAT_BLOCK;
+  if (n < block) {
+    vec_scal<float, platform::jit::isa_any>(n, a, x, y);
+    return;
+  }
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+  __m256 scalar = _mm256_set1_ps(a);
+  __m256 tmp;
+#define MOVE_ONE_STEP               \
+  tmp = _mm256_loadu_ps(x + i);     \
+  tmp = _mm256_mul_ps(tmp, scalar); \
+  _mm256_storeu_ps(y + i, tmp)
+  for (i = 0; i < end; i += block) {
+    MOVE_ONE_STEP;
+  }
+#undef MOVE_ONE_STEP
+  if (rest == 0) {
+    return;
+  }
+  // can not continue move step if src and dst are inplace
+  for (i = n - rest; i < n; ++i) {
+    y[i] = a * x[i];
+  }
+#else
+  vec_scal<float, platform::jit::isa_any>(n, a, x, y);
+#endif
+}
+
+template <>
+inline void vec_scal<float, platform::jit::avx2>(const int n, const float a,
+                                                 const float* x, float* y) {
+  vec_scal<float, platform::jit::avx>(n, a, x, y);
+}
+
+template <>
+inline void vec_scal<float, platform::jit::avx512_common>(const int n,
+                                                          const float a,
+                                                          const float* x,
+                                                          float* y) {
+  // TODO(TJ): enable me
+  vec_scal<float, platform::jit::avx2>(n, a, x, y);
+}
+
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] + a;
+  }
+}
+
+template <>
+inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
+                                                    const float* x, float* y) {
+#ifdef __AVX__
+  constexpr int block = AVX_FLOAT_BLOCK;
+  if (n < block) {
+    vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
+    return;
+  }
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+  __m256 bias = _mm256_set1_ps(a);
+  __m256 tmp;
+#define MOVE_ONE_STEP             \
+  tmp = _mm256_loadu_ps(x + i);   \
+  tmp = _mm256_add_ps(tmp, bias); \
+  _mm256_storeu_ps(y + i, tmp)
+  for (i = 0; i < end; i += block) {
+    MOVE_ONE_STEP;
+  }
+#undef MOVE_ONE_STEP
+  if (rest == 0) {
+    return;
+  }
+  // can not continue move step if src and dst are inplace
+  for (i = n - rest; i < n; ++i) {
+    y[i] = x[i] + a;
+  }
+#else
+  vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
+#endif
+}
+
+template <>
+inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a,
+                                                     const float* x, float* y) {
+  vec_add_bias<float, platform::jit::avx>(n, a, x, y);
+}
+
+template <>
+inline void vec_add_bias<float, platform::jit::avx512_common>(const int n,
+                                                              const float a,
+                                                              const float* x,
+                                                              float* y) {
+  // TODO(TJ): enable me
+  vec_add_bias<float, platform::jit::avx2>(n, a, x, y);
 }
 
 template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
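
Note: both AVX specializations above share one skeleton: broadcast the scalar into a __m256, stream full 8-float blocks through loadu/op/storeu, then finish the n % 8 remainder with scalar code, because an overlapping vector tail would re-read elements already written when x and y alias (the in-code comments call this out). A distilled sketch of that skeleton, assuming AVX is available at compile time; axpb is an illustrative name:

#include <immintrin.h>

void axpb(const int n, const float a, const float b, const float* x,
          float* y) {  // y[i] = a * x[i] + b
  const int block = 8;            // floats per 256-bit register
  const int end = n - n % block;  // last index covered by full blocks
  __m256 va = _mm256_set1_ps(a);
  __m256 vb = _mm256_set1_ps(b);
  int i = 0;
  for (; i < end; i += block) {
    __m256 t = _mm256_loadu_ps(x + i);
    t = _mm256_add_ps(_mm256_mul_ps(t, va), vb);
    _mm256_storeu_ps(y + i, t);
  }
  for (; i < n; ++i) {  // scalar tail, safe even when y == x
    y[i] = a * x[i] + b;
  }
}
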
@@ -45,18 +198,97 @@ inline void vec_sigmoid(const int n, const T* x, T* y) {
   const T min = SIGMOID_THRESHOLD_MIN;
   const T max = SIGMOID_THRESHOLD_MAX;
   for (int i = 0; i < n; ++i) {
-    T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
-    y[i] = 1.0 / (1.0 + std::exp(-tmp));
+    y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
+    y[i] = static_cast<T>(0) - y[i];
+  }
+  vec_exp<T>(n, y, y);
+  for (int i = 0; i < n; ++i) {
+    y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
+  }
+}
+
+template <>
+inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
+                                                   float* y) {
+#ifdef __AVX__
+  constexpr int block = AVX_FLOAT_BLOCK;
+  if (n < block) {
+    vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
+    return;
   }
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+  __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
+  __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
+  __m256 zeros = _mm256_setzero_ps();
+  __m256 tmp;
+#define MOVE_ONE_STEP              \
+  tmp = _mm256_loadu_ps(x + i);    \
+  tmp = _mm256_max_ps(tmp, min);   \
+  tmp = _mm256_min_ps(tmp, max);   \
+  tmp = _mm256_sub_ps(zeros, tmp); \
+  _mm256_storeu_ps(y + i, tmp)
+  for (i = 0; i < end; i += block) {
+    MOVE_ONE_STEP;
+  }
+#undef MOVE_ONE_STEP
+  if (rest != 0) {
+    // can not continue move step since the src and dst address could be equal
+    const float xmin = SIGMOID_THRESHOLD_MIN;
+    const float xmax = SIGMOID_THRESHOLD_MAX;
+    for (i = n - rest; i < n; ++i) {
+      y[i] = 0.f - ((x[i] < xmin) ? xmin : ((x[i] > xmax) ? xmax : x[i]));
+    }
+  }
+
+  vec_exp<float>(n, y, y);
+
+  __m256 ones = _mm256_set1_ps(1.0f);
+#define MOVE_ONE_STEP             \
+  tmp = _mm256_loadu_ps(y + i);   \
+  tmp = _mm256_add_ps(ones, tmp); \
+  tmp = _mm256_div_ps(ones, tmp); \
+  _mm256_storeu_ps(y + i, tmp)
+  for (i = 0; i < end; i += block) {
+    MOVE_ONE_STEP;
+  }
+#undef MOVE_ONE_STEP
+  if (rest == 0) {
+    return;
+  }
+  // can not continue move step
+  for (i = n - rest; i < n; ++i) {
+    y[i] = 1.f / (1.f + y[i]);
+  }
+#else
+  vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
+#endif
+}
+
+template <>
+inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x,
+                                                    float* y) {
+  vec_sigmoid<float, platform::jit::avx>(n, x, y);
+}
+
+template <>
+inline void vec_sigmoid<float, platform::jit::avx512_common>(const int n,
+                                                             const float* x,
+                                                             float* y) {
+  // TODO(TJ): enable me
+  vec_sigmoid<float, platform::jit::avx2>(n, x, y);
 }
 
 template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
 inline void vec_tanh(const int n, const T* x, T* y) {
-  for (int i = 0; i < n; ++i) {
-    y[i] = tanh<T>(x[i]);
-  }
+  vec_scal<T, isa>(n, static_cast<T>(2), x, y);
+  vec_sigmoid<T, isa>(n, y, y);
+  vec_scal<T>(n, static_cast<T>(2), y);
+  vec_add_bias<T, isa>(n, static_cast<T>(-1), y, y);
 }
 
+// TODO(TJ): make relu clip
 template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
 inline void vec_relu(const int n, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
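
Note: the new vec_tanh gets vectorization for free by composing the identity tanh(x) = 2*sigmoid(2x) - 1 out of vec_scal, vec_sigmoid, and vec_add_bias (it therefore also inherits vec_sigmoid's input clipping, so extreme inputs can differ slightly from std::tanh). A quick standalone check of the identity itself:

#include <cassert>
#include <cmath>

int main() {
  for (float x = -5.f; x <= 5.f; x += 0.25f) {
    // 2 * sigmoid(2x) - 1, written out explicitly
    float via_sigmoid = 2.f / (1.f + std::exp(-2.f * x)) - 1.f;
    assert(std::fabs(via_sigmoid - std::tanh(x)) < 1e-6f);
  }
  return 0;
}
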
@@ -64,24 +296,56 @@ inline void vec_relu(const int n, const T* x, T* y) {
   }
 }
 
+template <>
+inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
+                                                float* y) {
+#ifdef __AVX__
+  constexpr int block = AVX_FLOAT_BLOCK;
+  if (n < block * 4) {
+    vec_relu<float, platform::jit::isa_any>(n, x, y);
+    return;
+  }
+
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+  __m256 zeros = _mm256_setzero_ps();
+  __m256 tmp;
+#define MOVE_ONE_STEP              \
+  tmp = _mm256_loadu_ps(x + i);    \
+  tmp = _mm256_max_ps(tmp, zeros); \
+  _mm256_storeu_ps(y + i, tmp)
+  for (i = 0; i < end; i += block) {
+    MOVE_ONE_STEP;
+  }
+  if (rest == 0) {
+    return;
+  }
+  i = n - block;
+  MOVE_ONE_STEP;
+#undef MOVE_ONE_STEP
+
+#else
+  vec_relu<float, platform::jit::isa_any>(n, x, y);
+#endif
+}
+
 template <>
 inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
                                                  float* y) {
-  // TODO(TJ): complete me
-  for (int i = 0; i < n; ++i) {
-    y[i] = x[i] > 0 ? x[i] : 0;
-  }
+  vec_relu<float, platform::jit::avx>(n, x, y);
 }
 
 template <>
-inline void vec_relu<float, platform::jit::avx512_common>(const int n, const float* x,
-                                                          float* y) {
-  // TODO(TJ): complete me
-  for (int i = 0; i < n; ++i) {
-    y[i] = x[i] > 0 ? x[i] : 0;
-  }
+inline void vec_relu<float, platform::jit::avx512_common>(const int n,
+                                                          const float* x,
+                                                          float* y) {
+  // TODO(TJ): enable me
+  vec_relu<float, platform::jit::avx2>(n, x, y);
 }
 
+// TODO(TJ): optimize double of sigmoid, tanh and relu if necessary
+
 template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
 class VecActivations {
  public:
@@ -96,7 +360,7 @@
     } else if (type == "identity" || type == "") {
       return vec_identity<T, isa>;
     }
-    PADDLE_THROW("Not support type %s.", type);
+    LOG(FATAL) << "Not support type: " << type;
   }
 };
diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bf6481c5ccd763189731fe1c43da7c7d9e9b15c1
--- /dev/null
+++ b/paddle/fluid/operators/math/cpu_vec_test.cc
@@ -0,0 +1,202 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <sys/time.h>
+#include <cmath>
+#include <cstring>
+#include <random>
+#include <vector>
+#include "gflags/gflags.h"
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+
+#include "paddle/fluid/operators/math/cpu_vec.h"
+
+inline double GetCurrentUS() {
+  struct timeval time;
+  gettimeofday(&time, NULL);
+  return 1e+6 * time.tv_sec + time.tv_usec;
+}
+constexpr int repeat = 1000;
+
+template <typename T>
+inline T _sigmoid(T x) {
+  const T min = SIGMOID_THRESHOLD_MIN;
+  const T max = SIGMOID_THRESHOLD_MAX;
+  T tmp = (x < min) ? min : ((x > max) ? max : x);
+  return static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
+}
+
+template <typename T>
+inline T _tanh(T x) {
+  return static_cast<T>(2) * _sigmoid<T>(static_cast<T>(2) * x) -
+         static_cast<T>(1);
+}
+
+template <typename T>
+void ref_sigmoid(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = _sigmoid(x[i]);
+  }
+}
+
+template <typename T>
+void ref_tanh(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = _tanh(x[i]);
+  }
+}
+template <typename T>
+void ref_relu(const int n, const T* x, T* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
+template <typename T>
+void RandomVec(const int n, T* a) {
+  static unsigned int seed = 100;
+  std::mt19937 rng(seed++);
+  std::uniform_real_distribution<double> uniform_dist(0, 1);
+  const T lower = static_cast<T>(-20.f);
+  const T upper = static_cast<T>(20.f);
+  for (int i = 0; i < n; ++i) {
+    a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
+  }
+}
+
+template <typename T>
+void TestAndBench(const int n, std::function<void(const int, const T*, T*)> tgt,
+                  std::function<void(const int, const T*, T*)> ref) {
+  std::vector<T> x(n);
+  std::vector<T> ytgt(n), yref(n);
+  RandomVec<T>(n, x.data());
+
+  const T* x_data = x.data();
+  T* ytgt_data = ytgt.data();
+  T* yref_data = yref.data();
+  auto st = GetCurrentUS();
+  for (int i = 0; i < repeat; ++i) {
+    tgt(n, x_data, ytgt_data);
+  }
+  auto mt = GetCurrentUS();
+  for (int i = 0; i < repeat; ++i) {
+    ref(n, x_data, yref_data);
+  }
+  auto et = GetCurrentUS();
+
+  VLOG(3) << "Vec size " << n << ": refer takes: " << (et - mt) / repeat
+          << " us, tgt takes: " << (mt - st) / repeat;
+  for (int i = 0; i < n; ++i) {
+    EXPECT_NEAR(ytgt_data[i], yref_data[i], 1e-3);
+  }
+}
+
+TEST(CpuVecTest, sigmoid) {
+  namespace jit = paddle::platform::jit;
+  using namespace paddle::operators::math;  // NOLINT
+  for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
+    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
+    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
+    TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512_common>,
+                        ref_sigmoid<float>);
+  }
+  TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
+}
+
+TEST(CpuVecTest, tanh) {
+  namespace jit = paddle::platform::jit;
+  using namespace paddle::operators::math;  // NOLINT
+  for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
+    TestAndBench<float>(sz, vec_tanh<float, jit::avx512_common>,
+                        ref_tanh<float>);
+  }
+  TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
+}
+
+TEST(CpuVecTest, relu) {
+  namespace jit = paddle::platform::jit;
+  using namespace paddle::operators::math;  // NOLINT
+  for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
+    TestAndBench<float>(sz, vec_relu<float, jit::avx512_common>,
+                        ref_relu<float>);
+  }
+  TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
+}
+
+template <typename T>
+void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
+                 std::function<void(const int, const T*, T*)> ref) {
+  std::vector<T> x(n);
+  std::vector<T> ytgt(n), yref(n);
+  RandomVec<T>(n, x.data());
+
+  const T* x_data = x.data();
+  T* yref_data = yref.data();
+  T* ytgt_data = ytgt.data();
+  std::memcpy(yref_data, x_data, sizeof(T) * n);
+  std::memcpy(ytgt_data, x_data, sizeof(T) * n);
+
+  ref(n, yref_data, yref_data);
+  tgt(n, ytgt_data, ytgt_data);
+
+  for (int i = 0; i < n; ++i) {
+    EXPECT_NEAR(ytgt_data[i], yref_data[i], 1e-3);
+  }
+}
+
+TEST(CpuVecTest, inplace_sigmoid) {
+  namespace jit = paddle::platform::jit;
+  using namespace paddle::operators::math;  // NOLINT
+  for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
+    TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
+    TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
+    TestInplace<float>(sz, vec_sigmoid<float, jit::avx512_common>,
+                       ref_sigmoid<float>);
+  }
+  TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
+}
+
+TEST(CpuVecTest, inplace_tanh) {
+  namespace jit = paddle::platform::jit;
+  using namespace paddle::operators::math;  // NOLINT
+  for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
+    TestInplace<float>(sz, vec_tanh<float, jit::avx512_common>,
+                       ref_tanh<float>);
+  }
+  TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
+}
+
+TEST(CpuVecTest, inplace_relu) {
+  namespace jit = paddle::platform::jit;
+  using namespace paddle::operators::math;  // NOLINT
+  for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
+    TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
+    TestInplace<float>(sz, vec_relu<float, jit::avx512_common>,
+                       ref_relu<float>);
+  }
+  TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
+}
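
Note: the TestInplace cases exist because several call sites run these kernels with y == x (for example vec_softmax above calls vec_exp(n, y, y)), and the AVX tails are the risky part. vec_relu may rewind to i = n - block and redo a full overlapping block because relu(relu(v)) == relu(v); the non-idempotent kernels fall back to a scalar remainder loop instead. A hypothetical sketch of what the rewind trick would do to a non-idempotent op like exp when run in place (assumes n >= block, as the real kernels guarantee via their small-n fallback):

#include <cmath>

void exp_with_overlapped_tail(const int n, float* y) {  // illustrative only
  const int block = 8;
  int i = 0;
  for (; i + block <= n; i += block) {  // full blocks
    for (int j = 0; j < block; ++j) y[i + j] = std::exp(y[i + j]);
  }
  if (i < n) {
    // Rewind to cover the remainder with one more full block: elements in
    // [n - block, i) were already transformed, so in place they would get
    // exp(exp(v)) -- which is why vec_scal/vec_add_bias/vec_sigmoid use a
    // scalar tail instead.
    for (int k = n - block; k < n; ++k) y[k] = std::exp(y[k]);
  }
}
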
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index f08c0e8e345179bf198ca9d50278b7f65e03ca2c..75d3856d0dda8b5bf6a4fa11954a611bf140c9bc 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -50,7 +50,7 @@ ENDIF()
 # memcpy depends on device_context, here add deps individually for
 # avoiding cycle dependencies
 cc_library(device_context SRCS device_context.cc init.cc DEPS malloc
-    place eigen3 stringpiece cpu_helper framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
+    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
 nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
 
 cc_test(init_test SRCS init_test.cc DEPS device_context)
diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h
index 5d17978dd7946596c490dc465dab51e7cf53a044..30c8fbcfce92a8b06a175ddf198cde572f72b2a4 100644
--- a/paddle/fluid/platform/cpu_info.h
+++ b/paddle/fluid/platform/cpu_info.h
@@ -51,7 +51,7 @@ typedef enum {
 } cpu_isa_t;  // Instruction set architecture
 
 // May I use some instruction
-inline bool MayIUse(const cpu_isa_t cpu_isa);
+bool MayIUse(const cpu_isa_t cpu_isa);
 
 }  // namespace jit
 
diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc
index 6f1f0c4796f3bae2fb419bf103cb6c0c5489bf65..020ce4d6f59412490657767a096f1ce185287864 100644
--- a/paddle/fluid/platform/init.cc
+++ b/paddle/fluid/platform/init.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/cpu_helper.h"
+#include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/init.h"
 #include "paddle/fluid/platform/place.h"
@@ -120,6 +121,22 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #ifndef PADDLE_WITH_MKLDNN
   platform::SetNumThreads(FLAGS_paddle_num_threads);
 #endif
+
+  if (platform::jit::MayIUse(platform::jit::avx512_common)) {
+#ifndef __AVX512F__
+    LOG(WARNING) << "AVX512F is available, Please re-compile on local machine";
+#endif
+  }
+  if (platform::jit::MayIUse(platform::jit::avx2)) {
+#ifndef __AVX2__
+    LOG(WARNING) << "AVX2 is available, Please re-compile on local machine";
+#endif
+  }
+  if (platform::jit::MayIUse(platform::jit::avx)) {
+#ifndef __AVX__
+    LOG(WARNING) << "AVX is available, Please re-compile on local machine";
+#endif
+  }
 }
 
 void InitGLOG(const std::string &prog_name) {
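
Two closing notes on the platform changes. Dropping inline from the MayIUse declaration is a linkage fix: an inline function must be defined in every translation unit that calls it, so an inline declaration whose only definition lives in cpu_info.cc would fail to link once new callers such as cpu_vec_test.cc and init.cc appear (hence also the new cpu_info dependencies added in the CMake files). The InitDevices warnings then compare two independent facts: what the host CPU reports at runtime (MayIUse, via cpuid) and what the binary was compiled to emit (__AVX__ and friends, set by the compiler). A distilled sketch of that check, with a hypothetical cpu_supports_avx2 probe standing in for MayIUse:

#include <cstdio>

bool cpu_supports_avx2();  // hypothetical runtime cpuid probe, defined elsewhere

void warn_if_underbuilt() {
#ifndef __AVX2__
  // The compiler emitted no AVX2 code paths, but the host could run them,
  // so the jit::avx2 kernels above would silently fall back to scalar code.
  if (cpu_supports_avx2()) {
    std::fprintf(stderr, "AVX2 is available, consider re-compiling\n");
  }
#endif
}
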