/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// RNN (LSTM / Peephole-LSTM / GRU) JIT kernels: picks xbyak-generated code,
// AVX intrinsics, or the reference implementation at kernel-pool lookup time.

#include "paddle/fluid/operators/math/jit_kernel.h"
#include <memory>  // shared_ptr, unique_ptr, make_shared, dynamic_pointer_cast
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif

#ifdef __AVX__
#include <immintrin.h>
#endif

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace detail {
// Vectorized exp() helpers defined in the detail implementation TU.
#ifdef __AVX__
__m256 ExpAVX(__m256 x);
#endif

#ifdef __AVX2__
__m256 ExpAVX2(__m256 x);
#endif
}  // namespace detail

namespace jit = platform::jit;

#ifdef __AVX__
typedef enum { kSigmoid, kRelu, kTanh, kIdentity } act_type;

// Runtime-polymorphic wrapper around one 8-float AVX activation.
class AVXAct {
 public:
  virtual ~AVXAct() = default;
  virtual __m256 Compute(__m256 x) const = 0;
};

template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct {
 public:
  // Primary template: only the explicit specializations below are usable.
  __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
};

// sigmoid(x) = 1 / (1 + exp(-x)), with x clipped to the sigmoid thresholds.
#define AVX_SIGMOID(isa, expisa)                                    \
  template <>                                                       \
  __m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const {       \
    __m256 ones = _mm256_set1_ps(1.0f);                             \
    x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));    \
    x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));    \
    x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);                     \
    x = expisa(x);                                                  \
    x = _mm256_add_ps(ones, x);                                     \
    return _mm256_div_ps(ones, x);                                  \
  }

// tanh(x) = 2 / (1 + exp(-2x)) - 1, with -2x clipped to EXP_MAX_INPUT.
#define AVX_TANH(isa, expisa)                                  \
  template <>                                                  \
  __m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const {     \
    __m256 ones = _mm256_set1_ps(1.0f);                        \
    x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x);               \
    x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT));       \
    x = expisa(x);                                             \
    x = _mm256_add_ps(ones, x);                                \
    x = _mm256_div_ps(_mm256_set1_ps(2.0f), x);                \
    return _mm256_sub_ps(x, ones);                             \
  }

#define AVX_RELU(isa)                                      \
  template <>                                              \
  __m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
    return _mm256_max_ps(x, _mm256_setzero_ps());          \
  }

#define AVX_IDENTITY(isa)                                      \
  template <>                                                  \
  __m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
    return x;                                                  \
  }

#define FOR_EACH_AVX_ISA(macro_) \
  macro_(jit::avx);              \
  macro_(jit::avx2);             \
  macro_(jit::avx512f)

FOR_EACH_AVX_ISA(AVX_RELU);
FOR_EACH_AVX_ISA(AVX_IDENTITY);

AVX_SIGMOID(jit::avx, detail::ExpAVX);
AVX_TANH(jit::avx, detail::ExpAVX);

#ifdef __AVX2__
// avx2 and avx512f specializations share the AVX2 exp.
AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
AVX_TANH(jit::avx2, detail::ExpAVX2);
AVX_TANH(jit::avx512f, detail::ExpAVX2);
#endif

#undef FOR_EACH_AVX_ISA
#undef AVX_IDENTITY
#undef AVX_RELU
#undef AVX_TANH
#undef AVX_SIGMOID
#endif

// Look up a size-n vector-activation kernel by name. "" maps to identity;
// any other unknown name throws.
template <typename T>
static std::shared_ptr<const VActKernel<T>> GetActKernel(
    const std::string& type, int n) {
  if (type == "sigmoid") {
    return std::dynamic_pointer_cast<const VActKernel<T>>(
        KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
  } else if (type == "relu") {
    return std::dynamic_pointer_cast<const VActKernel<T>>(
        KernelPool::Instance().template Get<VReluKernel<T>>(n));
  } else if (type == "tanh") {
    return std::dynamic_pointer_cast<const VActKernel<T>>(
        KernelPool::Instance().template Get<VTanhKernel<T>>(n));
  } else if (type == "identity" || type == "") {
    return std::dynamic_pointer_cast<const VActKernel<T>>(
        KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
  }
  PADDLE_THROW("Not support type: %s", type);
  return nullptr;
}

#ifdef __AVX__
// AVX counterpart of GetActKernel: returns an owning 8-float activation.
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
  if (type == "sigmoid") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
  } else if (type == "relu") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
  } else if (type == "tanh") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
  } else if (type == "identity" || type == "") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
  }
  PADDLE_THROW("Not support type: %s", type);
  return nullptr;
}
#endif

/* LSTM JitKernel */
template <typename T>
class LSTMKernelImpl : public LSTMKernel<T> {
 public:
  // Cache key; only the float/double specializations (see
  // JITKERNEL_DEFINE_NAME_LSTM below) are callable.
  static inline std::string name(const lstm_attr_t& attr) {
    PADDLE_THROW("DType should be either float or double");
  }
  static inline bool useJIT(int d) { return false; }
  static inline bool useMKL(int d) { return false; }
  explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(attr.d)) {
      // Rough upper bound on generated code size, floored at one page.
      size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8;  // should change
      jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
      this->ComputeCtHt =
          jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();

      jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
      this->ComputeC1H1 =
          jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
      return;
    }
#endif
    // Fallback: portable reference implementation.
    this->ComputeCtHt = refer::LSTMCtHt<T>;
    this->ComputeC1H1 = refer::LSTMC1H1<T>;
  }

#ifdef PADDLE_WITH_XBYAK

 private:
  std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
#endif
};

#ifdef PADDLE_WITH_XBYAK
template <>
bool LSTMKernelImpl<float>::useJIT(int d) {
  return false;  // not ready yet gen::LSTMJitCode::init(d);
}
#endif

/* Peephole JitKernel */
template <typename T>
class PeepholeKernelImpl : public LSTMKernel<T> {
 public:
  static inline std::string name(const lstm_attr_t& attr) {
    PADDLE_THROW("DType should be either float or double");
  }
  static inline bool useJIT(int d) { return false; }
  static inline bool useMKL(int d) { return false; }
  explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
    if (useJIT(attr.d)) {
      // Rough upper bound on generated code size, floored at one page.
      size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 84 * 8;  // should change
      jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
      this->ComputeCtHt =
          jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();

      jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
      this->ComputeC1H1 =
          jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
      return;
    }
#endif
    // Fallback: the reference kernels read attr.use_peephole themselves.
    this->ComputeCtHt = refer::LSTMCtHt<T>;
    this->ComputeC1H1 = refer::LSTMC1H1<T>;
  }

#ifdef PADDLE_WITH_XBYAK

 private:
  std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
#endif
};

#ifdef PADDLE_WITH_XBYAK
template <>
bool PeepholeKernelImpl<float>::useJIT(int d) {
  return false;  // peephole jitcode not ready yet
}
#endif

// Build the pool key: dtype tag + activation names + peephole flag, suffixed
// by the backend that will serve the request.
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class)                 \
  template <>                                                          \
  std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) {  \
    std::string key(#ker_key "f");                                     \
    key += (attr.act_gate + attr.act_cand + attr.act_cell +            \
            (attr.use_peephole ? "p" : "n"));                          \
    if (useJIT(attr.d)) {                                              \
      /* only jit code need record d*/                                 \
      return key + "jit" + std::to_string(attr.d);                     \
    } else if (useMKL(attr.d)) {                                       \
      return key + "mkl";                                              \
    } else {                                                           \
      return key + "any";                                              \
    }                                                                  \
  }                                                                    \
  template <>                                                          \
  std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
    std::string key(#ker_key "d");                                     \
    /* jit code do not support double yet*/                            \
    if (useMKL(attr.d)) {                                              \
      return key + "mkl";                                              \
    } else {                                                           \
      return key + "any";                                              \
    }                                                                  \
  }

#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype)          \
  template <>                                                 \
  std::shared_ptr<const LSTMKernel<ker_dtype>>                \
  KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
      const lstm_attr_t& attr)

#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
  std::string key = ker_class##Impl<ker_dtype>::name(attr)

// Peephole LSTMs get their own impl class; plain LSTMs use ker##KernelImpl.
#define JITKERNEL_LSTM_IMPL(ker, dtype)                        \
  if (attr.use_peephole) {                                     \
    p = std::dynamic_pointer_cast<ker<dtype>>(                 \
        std::make_shared<PeepholeKernelImpl<dtype>>(attr));    \
  } else {                                                     \
    p = std::dynamic_pointer_cast<ker<dtype>>(                 \
        std::make_shared<ker##KernelImpl<dtype>>(attr));       \
  }

REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM,
                        JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM,
                        JITKERNEL_LSTM_IMPL);

/* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class GRUKernelImpl : public GRUKernel<T> {
 public:
  explicit GRUKernelImpl(const std::string& act_gate,
                         const std::string& act_state, int d)
      : GRUKernel<T>() {
    d_ = d;
    d2_ = d * 2;
    act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
    act_gate_d_ = GetActKernel<T>(act_gate, d);
    act_state_d_ = GetActKernel<T>(act_state, d);
    vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
  }
  void ComputeH1(T* gates, T* ht) const override {
    act_gate_d_->Compute(gates, gates, d_);
    act_state_d_->Compute(gates + d2_, gates + d2_, d_);
    vmul_d_->Compute(gates, gates + d2_, ht, d_);
  }
  void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
    // W: {W_update, W_reset; W_state}
    act_gate_d2_->Compute(gates, gates, d2_);
    vmul_d_->Compute(ht_1, gates + d_, ht, d_);
  }
  void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
    T* y = gates + d2_;
    act_state_d_->Compute(y, y, d_);
    // out = zt*ht~ + (1-zt)*ht_1
    for (int i = 0; i < d_; ++i) {
      ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
    }
  }

 private:
  int d_, d2_;
  std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
  std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
  std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
};

// Intrinsic specialization for d == 8 (one YMM register per gate).
// gates layout: {update(8) | reset(8) | state(8)}, hence offsets 8 and 16.
#define INTRI8_FLOAT(isa)                                                     \
  template <>                                                                 \
  GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl(                             \
      const std::string& act_gate, const std::string& act_state, int d)       \
      : GRUKernel<float>() {                                                  \
    avx_act_gate_ = GetAVXAct<isa>(act_gate);                                 \
    avx_act_state_ = GetAVXAct<isa>(act_state);                               \
  }                                                                           \
  template <>                                                                 \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht)    \
      const {                                                                 \
    __m256 u, s;                                                              \
    /* W: {W_update, W_reset; W_state} */                                     \
    u = _mm256_loadu_ps(gates);                                               \
    s = _mm256_loadu_ps(gates + 16);                                          \
    s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
    _mm256_storeu_ps(ht, s);                                                  \
  }                                                                           \
  template <>                                                                 \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1(                       \
      float* gates, const float* ht_1, float* ht) const {                     \
    /* not exactly equal the any implementation */                            \
    __m256 r, ht0;                                                            \
    r = _mm256_loadu_ps(gates + 8);                                           \
    ht0 = _mm256_loadu_ps(ht_1);                                              \
    r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0);                        \
    _mm256_storeu_ps(ht, r);                                                  \
  }                                                                           \
  template <>                                                                 \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2(                       \
      float* gates, const float* ht_1, float* ht) const {                     \
    /* not exactly equal the any implementation */                            \
    __m256 u, s, ht0;                                                         \
    u = _mm256_loadu_ps(gates);                                               \
    s = _mm256_loadu_ps(gates + 16);                                          \
    ht0 = _mm256_loadu_ps(ht_1);                                              \
    u = avx_act_gate_->Compute(u);                                            \
    s = _mm256_mul_ps(u, avx_act_state_->Compute(s));                         \
    u = _mm256_sub_ps(_mm256_set1_ps(1.f), u);                                \
    u = _mm256_mul_ps(u, ht0);                                                \
    u = _mm256_add_ps(s, u);                                                  \
    _mm256_storeu_ps(ht, u);                                                  \
  }

#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif

#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype)                       \
  template <>                                                             \
  std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get<            \
      GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
      const std::string& act_gate, const std::string& act_state, int d)

#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
  #ker_key #dtype_key + std::to_string(d) + act_gate + act_state

#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k)       \
  p = std::dynamic_pointer_cast<ker<dtype>>(             \
      std::make_shared<ker##KernelImpl<dtype, isa, k>>(  \
          act_gate, act_state, d));

REGISTER_JITKERNEL_ARGS_DEPRECATED(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
                                   JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);

#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle