From f2adaf1c3ec4774955ec7f52b9b3d44e02684504 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Mon, 8 Oct 2018 22:18:31 +0800
Subject: [PATCH] add vrelu and lstm kernel

test=develop
---
 paddle/fluid/operators/math/jit_kernel.cc     |  17 ---
 paddle/fluid/operators/math/jit_kernel.h      |  33 +++--
 .../fluid/operators/math/jit_kernel_blas.cc   | 109 +++++++++++++++
 paddle/fluid/operators/math/jit_kernel_exp.cc |   1 +
 .../fluid/operators/math/jit_kernel_lstm.cc   | 130 +++++++++++-------
 .../fluid/operators/math/jit_kernel_test.cc   |  54 ++++++++
 6 files changed, 269 insertions(+), 75 deletions(-)

diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc
index 18a58cbea7..54292cd710 100644
--- a/paddle/fluid/operators/math/jit_kernel.cc
+++ b/paddle/fluid/operators/math/jit_kernel.cc
@@ -35,23 +35,6 @@ std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
   return kers_.at(key);
 }
 
-template <>
-std::shared_ptr<const LSTMKernel<float>>
-KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
-                const std::string&>(int d, const std::string& act_gate,
-                                    const std::string& act_cand,
-                                    const std::string& act_cell) {
-  std::string key =
-      "lstmf" + std::to_string(d) + act_gate + act_cand + act_cell;
-  if (kers_.find(key) == kers_.end()) {
-    auto p =
-        std::make_shared<LSTMKernel<float>>(d, act_gate, act_cand, act_cell);
-    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
-    return p;
-  }
-  return std::dynamic_pointer_cast<const LSTMKernel<float>>(kers_.at(key));
-}
-
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 173cc36887..6edfdf22d1 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -87,36 +87,45 @@ class VAddBiasKernel : public Kernel {
 };
 
 template <typename T>
-class VExpKernel : public Kernel {
+class VActKernel : public Kernel {
  public:
   virtual void Compute(const T *x, T *y) const = 0;
 };
 
 template <typename T>
-class VSigmoidKernel : public Kernel {
+class VReluKernel : public VActKernel<T> {
  public:
   virtual void Compute(const T *x, T *y) const = 0;
 };
 
 template <typename T>
-class VTanhKernel : public Kernel {
+class VIdentityKernel : public VActKernel<T> {
  public:
   virtual void Compute(const T *x, T *y) const = 0;
 };
 
 template <typename T>
-class LSTMKernel : public Kernel {
+class VExpKernel : public VActKernel<T> {
  public:
-  explicit LSTMKernel(int d, const std::string &act_gate,
-                      const std::string &act_cand, const std::string &act_cell);
+  virtual void Compute(const T *x, T *y) const = 0;
+};
 
-  void (*jit_ker)(T *, const T *, T *, T *);
-  std::function<void(T *, const T *, T *, T *)> ComputeCtHt, ComputeCtHt_NoC0H0;
+template <typename T>
+class VSigmoidKernel : public VActKernel<T> {
+ public:
+  virtual void Compute(const T *x, T *y) const = 0;
+};
 
- private:
-  int d_, d2_, d3_;
-  std::function<void(const int, const T *, T *)> act_gate_, act_cell_,
-      act_cand_;
+template <typename T>
+class VTanhKernel : public VActKernel<T> {
+ public:
+  virtual void Compute(const T *x, T *y) const = 0;
+};
+
+template <typename T>
+class LSTMKernel : public Kernel {
+ public:
+  virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht) const = 0;
 };
 
 }  // namespace jitkernel
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index 4ea1a8cd5c..0f9ea533fc 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -266,15 +266,124 @@ INTRI16_FLOAT(jit::avx512f);
 #endif
 // TODO(TJ): eq16 test and complete avx512
 
+#undef INTRI8_FLOAT
+#undef INTRI16_FLOAT
+
+/* VRelu JitKernel */
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+class VReluKernelImpl : public VReluKernel<T> {
+ public:
+  explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
+  void Compute(const T* x, T* y) const override {
+    for (int i = 0; i < this->num_; ++i) {
+      y[i] = x[i] > 0 ? x[i] : 0;
+    }
+  }
+};
+
+#define INTRI8_FLOAT(isa)                                                    \
+  template <>                                                                \
+  void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
+      const {                                                                \
+    __m256 tmp = _mm256_loadu_ps(x);                                         \
+    tmp = _mm256_max_ps(tmp, _mm256_setzero_ps());                           \
+    _mm256_storeu_ps(y, tmp);                                                \
+  }
+
+#define INTRI16_FLOAT(isa)                                                    \
+  template <>                                                                 \
+  void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
+      const {                                                                 \
+    __m256 zeros = _mm256_setzero_ps();                                       \
+    __m256 tmp0 = _mm256_loadu_ps(x);                                         \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8);                                     \
+    tmp0 = _mm256_max_ps(tmp0, zeros);                                        \
+    tmp1 = _mm256_max_ps(tmp1, zeros);                                        \
+    _mm256_storeu_ps(y, tmp0);                                                \
+    _mm256_storeu_ps(y + 8, tmp1);                                            \
+  }
+
+#define INTRI_GT8LT16_FLOAT(isa)                                        \
+  template <>                                                           \
+  VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d)         \
+      : VReluKernel<float>() {                                          \
+    this->num_ = d;                                                     \
+    this->end_ = AVX_FLOAT_BLOCK;                                       \
+    this->rest_ = d - AVX_FLOAT_BLOCK;                                  \
+  }                                                                     \
+  template <>                                                           \
+  void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x,  \
+                                                      float* y) const { \
+    __m256 zeros = _mm256_setzero_ps();                                 \
+    __m256 tmp0 = _mm256_loadu_ps(x);                                   \
+    __m256 tmp1 = _mm256_loadu_ps(x + this->rest_);                     \
+    tmp0 = _mm256_max_ps(tmp0, zeros);                                  \
+    tmp1 = _mm256_max_ps(tmp1, zeros);                                  \
+    _mm256_storeu_ps(y, tmp0);                                          \
+    _mm256_storeu_ps(y + this->rest_, tmp1);                            \
+  }
+
+#define INTRI_GT16_FLOAT(isa)                                                 \
+  template <>                                                                 \
+  VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d)                  \
+      : VReluKernel<float>() {                                                \
+    this->num_ = d;                                                           \
+    this->end_ = d - d % AVX_FLOAT_BLOCK;                                     \
+    this->rest_ = d - AVX_FLOAT_BLOCK;                                        \
+  }                                                                           \
+  template <>                                                                 \
+  void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
+      const {                                                                 \
+    __m256 zeros = _mm256_setzero_ps();                                       \
+    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {                   \
+      __m256 tmp = _mm256_loadu_ps(x + i);                                    \
+      tmp = _mm256_max_ps(tmp, zeros);                                        \
+      _mm256_storeu_ps(y + i, tmp);                                           \
+    }                                                                         \
+    __m256 tmp = _mm256_loadu_ps(x + this->rest_);                            \
+    tmp = _mm256_max_ps(tmp, zeros);                                          \
+    _mm256_storeu_ps(y + this->rest_, tmp);                                   \
+  }
+
+#ifdef __AVX__
+INTRI8_FLOAT(jit::avx);
+INTRI16_FLOAT(jit::avx);
+INTRI_GT8LT16_FLOAT(jit::avx);
+INTRI_GT16_FLOAT(jit::avx);
+#endif
+#ifdef __AVX2__
+INTRI8_FLOAT(jit::avx2);
+INTRI16_FLOAT(jit::avx2);
+INTRI_GT8LT16_FLOAT(jit::avx2);
+INTRI_GT16_FLOAT(jit::avx2);
+#endif
+#ifdef __AVX512F__
+// TODO(TJ): refine avx512
+INTRI8_FLOAT(jit::avx512f);
+INTRI16_FLOAT(jit::avx512f);
+INTRI_GT8LT16_FLOAT(jit::avx512f);
+INTRI_GT16_FLOAT(jit::avx512f);
+#endif
+
 #undef INTRI8_FLOAT
 #undef INTRI16_FLOAT
 #undef INTRI_GT8LT16_FLOAT
 #undef INTRI_GT16_FLOAT
 
+/* An empty JitKernel */
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+class VIdentityKernelImpl : public VIdentityKernel<T> {
+ public:
+  explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
+  void Compute(const T* x, T* y) const override {}
+};
+
 REGISTER_JITKERNEL(vmul, VMulKernel);
 REGISTER_JITKERNEL(vadd, VAddKernel);
 REGISTER_JITKERNEL(vscal, VScalKernel);
 REGISTER_JITKERNEL(vaddb, VAddBiasKernel);
+REGISTER_JITKERNEL(vrelu, VReluKernel);
+REGISTER_JITKERNEL(videntity, VIdentityKernel);
 
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc
index 7e28a3a187..b62e130c43 100644
--- a/paddle/fluid/operators/math/jit_kernel_exp.cc
+++ b/paddle/fluid/operators/math/jit_kernel_exp.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing
 permissions and limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
+#include <cmath>  // for exp
 #include <string>
 #include "paddle/fluid/operators/math/jit_kernel_macro.h"
 #ifdef PADDLE_WITH_MKLML
diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_lstm.cc
index 895784a4fa..210b229b28 100644
--- a/paddle/fluid/operators/math/jit_kernel_lstm.cc
+++ b/paddle/fluid/operators/math/jit_kernel_lstm.cc
@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
-#include <functional>
 #include <string>
-#include "paddle/fluid/operators/math/cpu_vec.h"
+#include "paddle/fluid/operators/math/jit_kernel_macro.h"
+#include "paddle/fluid/platform/enforce.h"
+
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
 
 namespace paddle {
 namespace operators {
@@ -24,51 +28,85 @@ namespace jitkernel {
 
 namespace jit = platform::jit;
 
-template <>
-LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
-                              const std::string& act_cand_str,
-                              const std::string& act_cell_str)
-    : Kernel(), d_(d) {
-  d2_ = d * 2;
-  d3_ = d * 3;
-  if (platform::jit::MayIUse(platform::jit::avx512f)) {
-    math::VecActivations<float, platform::jit::avx512f> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
-  } else if (platform::jit::MayIUse(platform::jit::avx2)) {
-    math::VecActivations<float, platform::jit::avx2> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
-  } else if (platform::jit::MayIUse(platform::jit::avx)) {
-    math::VecActivations<float, platform::jit::avx> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
-    // ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) {
-    //  // gates: W_ch, W_ih, W_fh, W_oh
-    //  act_gate(d3_, gates + d_, gates + d_);
-
-    //  /* C_t = C_t-1 * fgated + cand_gated * igated */
-    //  act_cand(d_, gates, gates);
-    //  blas.VMUL(d_, gates, gates + d_, gates + d_);
-    //  blas.VMUL(d_, ct_1, gates + d2_, gates + d2_);
-    //  blas.VADD(d_, gates + d_, gates + d2_, ct);
-
-    //  /* H_t = act_cell(C_t) * ogated */
-    //  act_cell(d_, ct, gates + d2_);
-    //  blas.VMUL(d_, gates + d2_, gates + d3_, ht)
-    //  GET_Ct(ct_1, gates, ct);
-    //  GET_Ht(ct, gates, ht);
-    // };
-  } else {
-    math::VecActivations<float, platform::jit::isa_any> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
+/* LSTM JitKernel */
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+class LSTMKernelImpl : public LSTMKernel<T> {
+ public:
+  explicit LSTMKernelImpl(int d, const std::string& act_gate,
+                          const std::string& act_cand,
+                          const std::string& act_cell)
+      : LSTMKernel<T>() {
+    d_ = d;
+    d2_ = d * 2;
+    d3_ = d * 3;
+    auto GetActKernel = [&](const std::string& type,
+                            int n) -> std::shared_ptr<const VActKernel<T>> {
+      if (type == "sigmoid") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
+      } else if (type == "relu") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VReluKernel<T>>(n));
+      } else if (type == "tanh") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VTanhKernel<T>>(n));
+      } else if (type == "identity" || type == "") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
+      }
+      PADDLE_THROW("Not support type: %s", type);
+    };
+    act_gate_3d_ = GetActKernel(act_gate, d * 3);
+    act_cand_d_ = GetActKernel(act_cand, d);
+    act_cell_d_ = GetActKernel(act_cell, d);
+    vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
+    vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
+  }
+
+  void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override {
+    // gates: W_ch, W_ih, W_fh, W_oh
+    act_gate_3d_->Compute(gates + d_, gates + d_);
+
+    /* C_t = C_t-1 * fgated + cand_gated * igated */
+    act_cand_d_->Compute(gates, gates);
+    vmul_d_->Compute(gates, gates + d_, gates + d_);
+    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
+    vadd_d_->Compute(gates + d_, gates + d2_, ct);
+
+    /* H_t = act_cell(C_t) * ogated */
+    act_cell_d_->Compute(ct, gates + d2_);
+    vmul_d_->Compute(gates + d2_, gates + d3_, ht);
   }
-}
+
+ private:
+  int d_, d2_, d3_;
+  std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
+  std::shared_ptr<const VMulKernel<T>> vmul_d_;
+  std::shared_ptr<const VAddKernel<T>> vadd_d_;
+};
+
+#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype)                   \
+  template <>                                                          \
+  std::shared_ptr<const ker_class<ker_dtype>>                          \
+  KernelPool::Get<ker_class<ker_dtype>, int, const std::string&,       \
+                  const std::string&, const std::string&>(             \
+      int d, const std::string& act_gate, const std::string& act_cand, \
+      const std::string& act_cell)
+
+#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
+  #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell
+
+#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k)                      \
+  p = std::dynamic_pointer_cast<ker<dtype>>(                             \
+      std::make_shared<ker##Impl<dtype, isa, k>>(d, act_gate, act_cand,  \
+                                                 act_cell))
+
+REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
+                        JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
+
+#undef JITKERNEL_DECLARE_LSTM
+#undef JITKERNEL_KEY_LSTM
+#undef JITKERNEL_NEW_LSTM_IMPL
 
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 5e9e5c5b29..d2de4545ce 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
 #include <sys/time.h>
+#include <cmath>    // for exp
 #include <cstring>  // for memcpy
 #include <string>
 #include <vector>
@@ -48,6 +49,59 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
   }
 }
 
+void vrelu_ref(const int n, const float* x, float* y) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0.f ? x[i] : 0.f;
+  }
+}
+
+#if defined __AVX__ || defined __AVX2__
+void vrelu_intri8(const int n, const float* x, float* y) {
+  __m256 tmp = _mm256_loadu_ps(x);
+  tmp = _mm256_max_ps(tmp, _mm256_setzero_ps());
+  _mm256_storeu_ps(y, tmp);
+}
+#endif
+
+TEST(JitKernel, vrelu) {
+  namespace jit = paddle::operators::math::jitkernel;
+  for (int d : {7, 8, 15, 16, 30, 256, 512}) {
+    std::vector<float> x(d);
+    std::vector<float> zref(d), ztgt(d);
+    RandomVec<float>(d, x.data(), -10.f, 1.f);
+    const auto& ker =
+        jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
+    const float* x_data = x.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+    auto trefs = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vrelu_ref(d, x_data, zref_data);
+    }
+    auto trefe = GetCurrentUS();
+#if defined __AVX__ || defined __AVX2__
+    if (d == 8) {
+      auto si0 = GetCurrentUS();
+      for (int i = 0; i < repeat; ++i) {
+        vrelu_intri8(d, x_data, zref_data);
+      }
+      auto si1 = GetCurrentUS();
+      VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
+    }
+#endif
+    auto ttgts = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      ker->Compute(x_data, ztgt_data);
+    }
+    auto ttgte = GetCurrentUS();
+    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
+            << " us, tgt takes: " << (ttgte - ttgts) / repeat;
+    for (int i = 0; i < d; ++i) {
+      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
+    }
+  }
+}
+
 void vaddbias_ref(const int n, const float a, const float* x, float* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = x[i] + a;
-- 
GitLab
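
Usage note (not part of the patch): a minimal sketch of how the kernels registered above might be fetched from the KernelPool and invoked, mirroring the calls in jit_kernel_test.cc. The dimension, activation strings, buffer contents, and the helper function name are illustrative assumptions, not code from this commit.

#include <string>
#include <vector>
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jit = paddle::operators::math::jitkernel;

// Hypothetical helper (illustrative only): run the pooled VRelu kernel and
// one LSTM cell step on zero-initialized buffers of width d.
void jit_kernel_usage_sketch(int d) {
  // Element-wise ReLU through the kernel pool, as in TEST(JitKernel, vrelu).
  std::vector<float> x(d), y(d);
  const auto& vrelu =
      jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
  vrelu->Compute(x.data(), y.data());

  // LSTM cell step; gates are laid out as W_ch, W_ih, W_fh, W_oh (4 * d).
  // The activation names follow the strings accepted by GetActKernel in
  // jit_kernel_lstm.cc; picking sigmoid/tanh/tanh here is an assumption.
  std::vector<float> gates(4 * d), ct_1(d), ct(d), ht(d);
  const auto& lstm =
      jit::KernelPool::Instance()
          .template Get<jit::LSTMKernel<float>, int, const std::string&,
                        const std::string&, const std::string&>(
              d, "sigmoid", "tanh", "tanh");
  lstm->ComputeCtHt(gates.data(), ct_1.data(), ct.data(), ht.data());
}

The vrelu path is exactly what TEST(JitKernel, vrelu) above exercises; the LSTM call follows the specialization declared by JITKERNEL_DECLARE_LSTM and registered via REGISTER_JITKERNEL_ARGS.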