From d53c4756ad146a442c9bfcf6ae850d98e4835f80 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 19 Dec 2018 12:47:55 +0000 Subject: [PATCH] clean code and remove unused files test=develop --- paddle/fluid/operators/jit/README.md | 34 +- paddle/fluid/operators/jit/benchmark.cc | 5 +- paddle/fluid/operators/jit/helper.h | 2 +- paddle/fluid/operators/jit/kernel_base.h | 13 +- .../jit/more/intrinsic/crf_decoding.cc | 2 +- .../jit/more/intrinsic/crf_decoding.h | 6 +- .../jit/more/intrinsic/layer_norm.cc | 2 +- .../operators/jit/more/intrinsic/layer_norm.h | 5 +- paddle/fluid/operators/jit/more/mix/mix.cc | 14 +- paddle/fluid/operators/jit/more/mix/mix.h | 11 +- paddle/fluid/operators/jit/more/mkl/mkl.cc | 20 +- paddle/fluid/operators/jit/more/mkl/mkl.h | 13 +- paddle/fluid/operators/jit/test.cc | 4 +- paddle/fluid/operators/math/jit_code.cc | 334 -------- paddle/fluid/operators/math/jit_code.h | 532 ------------- paddle/fluid/operators/math/jit_gen.cc | 90 --- paddle/fluid/operators/math/jit_gen.h | 80 -- paddle/fluid/operators/math/jit_kernel.cc | 39 - paddle/fluid/operators/math/jit_kernel.h | 157 ---- .../fluid/operators/math/jit_kernel_blas.cc | 346 -------- .../operators/math/jit_kernel_crf_decode.cc | 291 ------- paddle/fluid/operators/math/jit_kernel_exp.cc | 195 ----- paddle/fluid/operators/math/jit_kernel_impl.h | 34 - .../operators/math/jit_kernel_layer_norm.cc | 239 ------ .../fluid/operators/math/jit_kernel_macro.h | 179 ----- .../fluid/operators/math/jit_kernel_refer.h | 29 - paddle/fluid/operators/math/jit_kernel_rnn.cc | 263 ------- .../fluid/operators/math/jit_kernel_test.cc | 742 ------------------ 28 files changed, 73 insertions(+), 3608 deletions(-) delete mode 100644 paddle/fluid/operators/math/jit_code.cc delete mode 100644 paddle/fluid/operators/math/jit_code.h delete mode 100644 paddle/fluid/operators/math/jit_gen.cc delete mode 100644 paddle/fluid/operators/math/jit_gen.h delete mode 100644 paddle/fluid/operators/math/jit_kernel.cc delete mode 100644 paddle/fluid/operators/math/jit_kernel.h delete mode 100644 paddle/fluid/operators/math/jit_kernel_blas.cc delete mode 100644 paddle/fluid/operators/math/jit_kernel_crf_decode.cc delete mode 100644 paddle/fluid/operators/math/jit_kernel_exp.cc delete mode 100644 paddle/fluid/operators/math/jit_kernel_impl.h delete mode 100644 paddle/fluid/operators/math/jit_kernel_layer_norm.cc delete mode 100644 paddle/fluid/operators/math/jit_kernel_macro.h delete mode 100644 paddle/fluid/operators/math/jit_kernel_refer.h delete mode 100644 paddle/fluid/operators/math/jit_kernel_rnn.cc delete mode 100644 paddle/fluid/operators/math/jit_kernel_test.cc diff --git a/paddle/fluid/operators/jit/README.md b/paddle/fluid/operators/jit/README.md index 1264bc96ee6..ae504c710d8 100644 --- a/paddle/fluid/operators/jit/README.md +++ b/paddle/fluid/operators/jit/README.md @@ -1,7 +1,8 @@ # JIT Kernel 结合函数模板和JIT生成需要的kernel函数。 -这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。 +这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的`UseMe`函数负责什么条件下可以被调用。 +这里实现的函数可以非常细粒度的函数方法,比如Vector mul, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。 目前仅支持CPU上的高性能计算。 ## 目录结构 @@ -21,6 +22,8 @@ PaddlePaddle/Paddle/paddle/fluid/ │ │ └── ... │ ├── mkldnn/ │ │ └── ... + │ ├── mix/ + │ │ └── ... │ ├── intrinsic/ │ │ └── ... │ └── openblas/ @@ -29,28 +32,35 @@ PaddlePaddle/Paddle/paddle/fluid/ └── ... ``` -基础class都的根目录下,根目录下包括jitcode,more和refer。每个目录下都是一种实现,每种kernel算子都需要有reference的实现,其他的都是可选的。 -- jitcode: 代表使用jit生成的code,需要依赖xbyak。他关心的是性能。 -- refer:代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑。 -- more: 下面可以放入跟多实现,包括mkl,mkldnn,openblas等,也可以是自身已有的kernel组合。 +基本类的定义都放在根目录下,根目录下包括gen,more和refer三个目录。每个目录下都是一种或者多种实现,每种kernel算子都需要有reference的实现,用作单元测试的基准,其他的实现都是可选的。 +- gen: 代表使用jit生成的code,需要依赖xbyak库。该实现最关心的就是性能。 +- refer: 代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑的正确性。 +- more: 下面可以放入跟多实现,可以包括mkl,mkldnn,intrinsic,openblas等,也可以是自身已有的kernel组合。 ## 动态获取 -提供一个get方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。 +提供一个`jit::Get`方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。 ## 测试 - 逻辑测试 所有实现都要与refer的code对比,需要满足精度要求, 包括float和double的数据类型 - 性能测试 - 所有实现的性能对比,并且与最终的`jit::Get`方法对比,该方法拿到的性能需要是最好的。 + 所有实现的性能对比,并且与最终的`jit::Get`方法对比,该方法拿到的性能需要在各种条件下都是最好的。 # 如何添加新的算子 - 在`KernelType` 中添加 `your_key` . -- 实现Reference 的逻辑,每个jitkernel的Reference 实现是必须的。不要依赖任何第三方库。并在`refer/CmakeLists.txt`中`USE_JITKERNEL_REFER(your_key)`. -- (optional) 实现更多的算法在`more`目录下,可以依赖mkl,openblas,或者mkldnn等第三方库。 -- (optional) 实现基于Xbyak的生成code,在`gen`目下。 jitcode需要实现自己的`JitCodeCreator`,并注册在KernelType上。 +- 实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CmakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`来使用该kernel. +- (optional) 实现更多的算法在`more`目录下,可以依赖mkl,intrinsic或者mkldnn等第三方库。 +- (optional) 实现基于Xbyak的生成code,在`gen`目下。 jitcode需要实现自己的`JitCodeCreator`,并注册在与refer相同的`KernelType`上。 - 必要时可以添加新的`KernelTuples`,可以参考`XYZNTuples`,新加的Attr类型需要特例化`JitCodeKey`方法。 -- 添加unit test,需要测试float和double -- 添加benchmark确保get得到的速度是最快。 +- 在`test.cc`中添加unit test,至少需要测试`float`和`double`两种数据类型,如有必要需要支持额外的数据类型,比如`int8`的相关函数。 +- 在`benchmark.cc`中添加相应的性能对比,同一种kernel需要对比所有实现,并且确保`jit::Get`得到的实现一直是速度最快的。 + +# 优点 +- 统一的Get方法,接口简单。 +- 同一套逻辑可以有多套实现,可以依赖多套第三方库,互不影响。 +- 目录结构清晰,不会在某个文件中有多个宏定义,导致的可读性差问题。 +- 优化方便,可以直接针对某种属性针对性优化,并不影响其他属性下的性能。 +- 可以支持多种平台,包括Linux,Mac 和 Windows,至少可以保证每种平台都可以正常work。后期也可以针对不同平台有针对的优化。框架层面可以使用统一接口,不必关心底层实现。 diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 4e5d530251e..1fae600500c 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -93,10 +93,11 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { auto more = i->GetFunc(); - infos.push_back(std::make_pair("More", benchmark(more, args...))); + infos.push_back( + std::make_pair(i->ImplType(), benchmark(more, args...))); } } } diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 38bc7cd8e89..412df86aa1c 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -107,7 +107,7 @@ typename KernelTuples::func_type Get( if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { return i->GetFunc(); } diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index e8b73bd83cd..ae8f3d68fa3 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -144,28 +144,27 @@ class Kernel { }; template -class KernelImpl : public Kernel { - // TODO(TJ): rename KernelImpl to KernelMore which seems only used in more - // and add name interface for more implements easy for debug +class KernelMore : public Kernel { public: using T = typename KernelTuples::data_type; using Func = typename KernelTuples::func_type; using Attr = typename KernelTuples::attr_type; virtual Func GetFunc() const { return func; } - // TODO(TJ): const &attr - virtual bool UseMe(Attr attr) const = 0; + virtual bool UseMe(const Attr& attr) const = 0; + virtual const char* ImplType() const = 0; protected: Func func{nullptr}; }; template -class ReferKernel : public KernelImpl { +class ReferKernel : public KernelMore { public: // Refer code can always be used - bool UseMe(typename KernelTuples::attr_type attr) const override { + bool UseMe(const typename KernelTuples::attr_type& attr) const override { return true; } + const char* ImplType() const override { return "Refer"; } }; } // namespace jit diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc index 17b5eaf13df..f06892a4077 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc +++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc @@ -156,7 +156,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w, } } -bool CRFDecodingKernel::UseMe(int d) const { +bool CRFDecodingKernel::UseMe(const int& d) const { return platform::MayIUse(platform::avx); } diff --git a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h index a4081cfc34b..24179d90ddc 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h +++ b/paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h @@ -26,10 +26,12 @@ namespace intrinsic { void CRFDecoding(const int seq_len, const float* x, const float* w, float* alpha, int* track, int tag_num); -class CRFDecodingKernel : public KernelImpl> { +class CRFDecodingKernel : public KernelMore> { public: CRFDecodingKernel() { this->func = CRFDecoding; } - bool UseMe(typename CRFDecodingTuples::attr_type) const override; + bool UseMe( + const typename CRFDecodingTuples::attr_type&) const override; + const char* ImplType() const override { return "Intrinsic"; } }; } // namespace intrinsic diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc index fafc12914e3..bac709bc9ea 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc +++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc @@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var, } } -bool LayerNormKernel::UseMe(int d) const { +bool LayerNormKernel::UseMe(const int& d) const { return platform::MayIUse(platform::avx); } diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h index b802f56f57f..89da2940f44 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h +++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.h @@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var, const float* scale, const float* bias, int height, const float epsilon, int right); -class LayerNormKernel : public KernelImpl> { +class LayerNormKernel : public KernelMore> { public: LayerNormKernel() { this->func = LayerNorm; } - bool UseMe(typename LayerNormTuples::attr_type) const override; + bool UseMe(const typename LayerNormTuples::attr_type&) const override; + const char* ImplType() const override { return "Intrinsic"; } }; } // namespace intrinsic diff --git a/paddle/fluid/operators/jit/more/mix/mix.cc b/paddle/fluid/operators/jit/more/mix/mix.cc index d8d5e30d010..924278eaa23 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.cc +++ b/paddle/fluid/operators/jit/more/mix/mix.cc @@ -180,19 +180,19 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { } // TODO(TJ): tuning me -bool VSigmoidKernel::UseMe(int d) const { return true; } +bool VSigmoidKernel::UseMe(const int& d) const { return true; } -bool VTanhKernel::UseMe(int d) const { return true; } +bool VTanhKernel::UseMe(const int& d) const { return true; } -bool LSTMCtHtKernel::UseMe(lstm_attr_t attr) const { return true; } +bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; } -bool LSTMC1H1Kernel::UseMe(lstm_attr_t attr) const { return true; } +bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; } -bool GRUH1Kernel::UseMe(gru_attr_t attr) const { return true; } +bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; } -bool GRUHtPart1Kernel::UseMe(gru_attr_t attr) const { return true; } +bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; } -bool GRUHtPart2Kernel::UseMe(gru_attr_t attr) const { return true; } +bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; } } // namespace mix } // namespace more diff --git a/paddle/fluid/operators/jit/more/mix/mix.h b/paddle/fluid/operators/jit/more/mix/mix.h index 85c8fd4c321..a70ecdf9348 100644 --- a/paddle/fluid/operators/jit/more/mix/mix.h +++ b/paddle/fluid/operators/jit/more/mix/mix.h @@ -33,11 +33,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr); void GRUHtPart1(gru_t* step, const gru_attr_t* attr); void GRUHtPart2(gru_t* step, const gru_attr_t* attr); -#define DECLARE_MORE_KERNEL(name, tuples) \ - class name##Kernel : public KernelImpl> { \ - public: \ - name##Kernel() { this->func = name; } \ - bool UseMe(typename tuples::attr_type) const override; \ +#define DECLARE_MORE_KERNEL(name, tuples) \ + class name##Kernel : public KernelMore> { \ + public: \ + name##Kernel() { this->func = name; } \ + bool UseMe(const typename tuples::attr_type&) const override; \ + const char* ImplType() const override { return "Mixed"; } \ } // XYN diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 42f6df576b1..f4e334d88d4 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -74,39 +74,39 @@ void VExp(const double* x, double* y, int n) { // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 template <> -bool VMulKernel::UseMe(int d) const { +bool VMulKernel::UseMe(const int& d) const { return platform::MayIUse(platform::avx512f) && d > 512; } template <> -bool VAddKernel::UseMe(int d) const { +bool VAddKernel::UseMe(const int& d) const { return platform::MayIUse(platform::avx512f) && d > 512; } template <> -bool VScalKernel::UseMe(int d) const { +bool VScalKernel::UseMe(const int& d) const { return platform::MayIUse(platform::avx512f) && d > 512; } template <> -bool VExpKernel::UseMe(int d) const { +bool VExpKernel::UseMe(const int& d) const { return d > 7; } template <> -bool VSigmoidKernel::UseMe(int d) const { +bool VSigmoidKernel::UseMe(const int& d) const { return d > 7; } template <> -bool VTanhKernel::UseMe(int d) const { +bool VTanhKernel::UseMe(const int& d) const { return d > 7; } -#define AWALYS_USE_ME_WITH_DOUBLE(func) \ - template <> \ - bool func##Kernel::UseMe(int d) const { \ - return true; \ +#define AWALYS_USE_ME_WITH_DOUBLE(func) \ + template <> \ + bool func##Kernel::UseMe(const int& d) const { \ + return true; \ } AWALYS_USE_ME_WITH_DOUBLE(VMul); diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index bf209d2f9d2..ee1031c028f 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -60,12 +60,13 @@ void VTanh(const T* x, T* y, int n) { } } -#define DECLARE_MKL_KERNEL(name, tuples) \ - template \ - class name##Kernel : public KernelImpl> { \ - public: \ - name##Kernel() { this->func = name; } \ - bool UseMe(typename tuples::attr_type) const override; \ +#define DECLARE_MKL_KERNEL(name, tuples) \ + template \ + class name##Kernel : public KernelMore> { \ + public: \ + name##Kernel() { this->func = name; } \ + bool UseMe(const typename tuples::attr_type&) const override; \ + const char* ImplType() const override { return "MKL"; } \ } // XYZN diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 32937d9c005..5be7cc5d1c8 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -228,10 +228,10 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { auto more = i->GetFunc(); - VLOG(10) << "Test More Kernel "; + VLOG(10) << "Test More Kernel : " << i->ImplType(); test(more, args...); } } diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc deleted file mode 100644 index 2b08c105971..00000000000 --- a/paddle/fluid/operators/math/jit_code.cc +++ /dev/null @@ -1,334 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_code.h" -#include // offsetof -#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace gen { - -using namespace platform; // NOLINT - -bool VXXJitCode::init(int d, int scalar_index) { - // It's not necessary to use avx512 since it would slow down the frequency - // and this kernel is not compute bound. - return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2; -} - -void VXXJitCode::generate() { - // do not need push stack, and do not need save avx512reg if do not use avx512 - int offset = 0; - if (with_relu_) { - vxorps(ymm_zero, ymm_zero, ymm_zero); - } - if (scalar_index_ == 1) { - vbroadcastss(ymm_src1, ptr[param1]); - } else if (scalar_index_ == 2) { - vbroadcastss(ymm_src2, ptr[param2]); - } - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - if (scalar_index_ != 1) { - vmovups(ymm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovups(ymm_src2, ptr[param2 + offset]); - } - if (type_ == operand_type::mul) { - vmulps(ymm_dst, ymm_src1, ymm_src2); - } else if (type_ == operand_type::add) { - vaddps(ymm_dst, ymm_src1, ymm_src2); - } - if (with_relu_) { - vmaxps(ymm_dst, ymm_zero, ymm_dst); - } - vmovups(ptr[param3 + offset], ymm_dst); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - int rest = num_ % YMM_FLOAT_BLOCK; - while (rest > 0) { - int block = XMM_FLOAT_BLOCK; - if (rest >= 4) { - block = 4; - if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovups(xmm_src2, ptr[param2 + offset]); - } - } else if (rest >= 2) { - block = 2; - if (scalar_index_ != 1) { - vmovq(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovq(xmm_src2, ptr[param2 + offset]); - } - } else { - block = 1; - if (scalar_index_ != 1) { - vmovss(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovss(xmm_src2, ptr[param2 + offset]); - } - } - switch (type_) { - case operand_type::mul: - vmulps(xmm_dst, xmm_src1, xmm_src2); - break; - case operand_type::add: - vaddps(xmm_dst, xmm_src1, xmm_src2); - break; - default: - break; - } - if (with_relu_) { - vmaxps(xmm_dst, xmm_zero, xmm_dst); - } - if (rest >= 4) { - vmovups(ptr[param3 + offset], xmm_dst); - } else if (rest >= 2) { - vmovq(ptr[param3 + offset], xmm_dst); - } else { - vmovss(ptr[param3 + offset], xmm_dst); - } - offset += sizeof(float) * block; - rest -= block; - } - ret(); -} - -const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = { - REPEAT_8TIMES(1.f), - REPEAT_8TIMES(2.f), - REPEAT_8TIMES(0.5f), - REPEAT_8TIMES(EXP_HIG), - REPEAT_8TIMES(EXP_LOW), - REPEAT_8TIMES(CEPHES_LOG2EF), - REPEAT_8TIMES(CEPHES_EXP_C1), - REPEAT_8TIMES(CEPHES_EXP_C2), - REPEAT_8TIMES(CEPHES_EXP_P0), - REPEAT_8TIMES(CEPHES_EXP_P1), - REPEAT_8TIMES(CEPHES_EXP_P2), - REPEAT_8TIMES(CEPHES_EXP_P3), - REPEAT_8TIMES(CEPHES_EXP_P4), - REPEAT_8TIMES(CEPHES_EXP_P5), - REPEAT_8TIMES(EXP_MAX_INPUT), - REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), - REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; - -const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)}; -int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0}; - -bool VActJitCode::init(int d, operand_type type) { - // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256 - return MayIUse(avx); -} - -void VActJitCode::generate() { - int offset = 0; - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - vmovups(ymm_src, ptr[param1 + offset]); - act(ymm_dst, ymm_src, type_); - vmovups(ptr[param2 + offset], ymm_dst); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - int rest = num_ % YMM_FLOAT_BLOCK; - while (rest > 0) { - int block = XMM_FLOAT_BLOCK; - if (rest >= 4) { - block = 4; - vmovups(xmm_src, ptr[param1 + offset]); - } else if (rest >= 2) { - block = 2; - vmovq(xmm_src, ptr[param1 + offset]); - } else { - block = 1; - vmovss(xmm_src, ptr[param1 + offset]); - } - act(xmm_dst, xmm_src, type_); - if (rest >= 4) { - vmovups(ptr[param2 + offset], xmm_dst); - } else if (rest >= 2) { - vmovq(ptr[param2 + offset], xmm_dst); - } else { - vmovss(ptr[param2 + offset], xmm_dst); - } - offset += sizeof(float) * block; - rest -= block; - } - ret(); -} - -bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } - -void LSTMJitCode::generate() { - if (use_peephole_) { - preCode(); - } - reg64_t reg_ptr_gates = rax; - reg64_t reg_ptr_ct_1 = r9; - reg64_t reg_ptr_ct = r10; - reg64_t reg_ptr_ht = r11; - reg64_t reg_ptr_wp = r12; - mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); - mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); - mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); - mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); - if (use_peephole_) { - mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]); - } - - int offset = 0; - int d = num_ * sizeof(float); - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - /* gates: W_ch, W_ih, W_fh, W_oh */ - ymm_t ymm_c = ymm_t(0); - ymm_t ymm_i = ymm_t(1); - ymm_t ymm_f = ymm_t(2); - ymm_t ymm_o = ymm_t(3); - ymm_t ymm_ct_1 = ymm_t(4); - ymm_t ymm_wp0 = ymm_t(5); - ymm_t ymm_wp1 = ymm_t(6); - ymm_t ymm_wp2 = ymm_t(7); - vmovups(ymm_c, ptr[reg_ptr_gates + offset]); - vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]); - vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]); - vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]); - if (!compute_c1h1_) { - vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]); - } - if (use_peephole_) { - vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]); - vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]); - vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]); - } - /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */ - // act_cand(c) - act(ymm_c, ymm_c, act_cand_); - // act_gate(i) or act_gate(ct_1 * wp0 + i) - if (!compute_c1h1_ && use_peephole_) { - vmulps(ymm_wp0, ymm_ct_1, ymm_wp0); - vaddps(ymm_i, ymm_i, ymm_wp0); - } - act(ymm_i, ymm_i, act_gate_); - vmulps(ymm_c, ymm_c, ymm_i); - if (!compute_c1h1_) { - // act_gate(f) or act_gate(ct_1 * wp1 + f) - if (use_peephole_) { - vmulps(ymm_wp1, ymm_ct_1, ymm_wp1); - vaddps(ymm_f, ymm_f, ymm_wp1); - } - act(ymm_f, ymm_f, act_gate_); - // ct - vmulps(ymm_f, ymm_f, ymm_ct_1); - vaddps(ymm_f, ymm_f, ymm_c); - } - /* H_t = act_cell(C_t) * act_gate(o) */ - // act_cell(C_t) - ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; - ymm_t ymm_tmp = ymm_i; - act(ymm_tmp, ymm_ct, act_cell_); - // act_gate(o) or act_gate(ct * wp2 + o) - if (use_peephole_) { - vmulps(ymm_wp2, ymm_ct, ymm_wp2); - vaddps(ymm_o, ymm_o, ymm_wp2); - } - act(ymm_o, ymm_o, act_gate_); - // ht - vmulps(ymm_o, ymm_o, ymm_tmp); - // save ct and ht - vmovups(ptr[reg_ptr_ct + offset], ymm_ct); - vmovups(ptr[reg_ptr_ht + offset], ymm_o); - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - - if (use_peephole_) { - postCode(); - } else { - ret(); - } -} - -bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; } - -void GRUJitCode::generate() { - reg64_t reg_ptr_gates = rax; - reg64_t reg_ptr_ht_1 = r9; - reg64_t reg_ptr_ht = r10; - mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]); - mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]); - mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]); - ymm_t ymm_one = ymm_t(0); - - if (id_ == 2) { - reg64_t reg_ptr_tmp = r11; - mov(reg_ptr_tmp, reinterpret_cast(exp_float_consts)); - vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]); - } - int offset = 0; - int d = num_ * sizeof(float); - for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { - ymm_t ymm_u = ymm_t(1); - ymm_t ymm_r = ymm_t(2); - ymm_t ymm_s = ymm_t(3); - ymm_t ymm_ht_1 = ymm_t(4); - // W: {W_update, W_reset; W_state} - if (id_ == 0 || id_ == 2) { - vmovups(ymm_u, ptr[reg_ptr_gates + offset]); - vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]); - } - if (id_ == 1) { - vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]); - } - if (id_ == 1 || id_ == 2) { - vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]); - } - - if (id_ == 0) { - // ht = act_gate(u) * act_cand(s) - act(ymm_u, ymm_u, act_gate_); - act(ymm_s, ymm_s, act_cand_); - vmulps(ymm_s, ymm_s, ymm_u); - vmovups(ptr[reg_ptr_ht + offset], ymm_s); - } else if (id_ == 1) { - // ht = act_gate(r) * ht_1 - act(ymm_r, ymm_r, act_gate_); - vmulps(ymm_r, ymm_r, ymm_ht_1); - vmovups(ptr[reg_ptr_ht + offset], ymm_r); - } else if (id_ == 2) { - // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1 - ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx()); - act(ymm_u, ymm_u, act_gate_); - act(ymm_s, ymm_s, act_cand_); - vmulps(ymm_s, ymm_s, ymm_u); - vsubps(ymm_u, ymm_one_inner, ymm_u); - vmulps(ymm_u, ymm_ht_1, ymm_u); - vaddps(ymm_u, ymm_s, ymm_u); - vmovups(ptr[reg_ptr_ht + offset], ymm_u); - } - offset += sizeof(float) * YMM_FLOAT_BLOCK; - } - - ret(); -} -} // namespace gen -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h deleted file mode 100644 index 6d22bf67572..00000000000 --- a/paddle/fluid/operators/math/jit_code.h +++ /dev/null @@ -1,532 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include "paddle/fluid/operators/math/jit_gen.h" -#include "paddle/fluid/operators/math/jit_kernel_impl.h" -#include "paddle/fluid/platform/cpu_info.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace gen { - -using reg64_t = const Xbyak::Reg64; -using reg32_t = const Xbyak::Reg32; -using xmm_t = const Xbyak::Xmm; -using ymm_t = const Xbyak::Ymm; -using zmm_t = const Xbyak::Zmm; -using Label = Xbyak::Label; - -typedef enum { - mul = 0, - add, - sub, - relu, - exp, - sigmoid, - tanh, - identity -} operand_type; - -extern const float exp_float_consts[]; -extern const int exp_int_0x7f[]; -extern int g_tmp_mem[]; - -#define EXP_HIG 88.3762626647949f -#define EXP_LOW -88.3762626647949f -#define CEPHES_LOG2EF 1.44269504088896341 -#define CEPHES_EXP_C1 0.693359375 -#define CEPHES_EXP_C2 -2.12194440e-4 -#define CEPHES_EXP_P0 1.9875691500E-4 -#define CEPHES_EXP_P1 1.3981999507E-3 -#define CEPHES_EXP_P2 8.3334519073E-3 -#define CEPHES_EXP_P3 4.1665795894E-2 -#define CEPHES_EXP_P4 1.6666665459E-1 -#define CEPHES_EXP_P5 5.0000001201E-1 - -#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val - -#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) - -// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) -class VXXJitCode : public JitCode { - public: - const char* name() const override { - std::string base = "VXXJitCode"; - if (scalar_index_ == 1) { - base += "_Scalar"; - } else { - base += "_Vec"; - } - if (type_ == operand_type::mul) { - base += "_Mul"; - } else if (type_ == operand_type::add) { - base += "_Add"; - } - if (scalar_index_ == 2) { - base += "_Scalar"; - } else { - base += "_Vec"; - } - base += (with_relu_ ? "_Relu" : ""); - return base.c_str(); - } - explicit VXXJitCode(int d, operand_type type, int scalar_index, - bool with_relu, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), - num_(d), - type_(type), - scalar_index_(scalar_index), - with_relu_(with_relu) {} - static bool init(int d, int scalar_index = 0); - void generate() override; - - private: - int num_; - operand_type type_; - int scalar_index_; - bool with_relu_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - reg64_t param3{abi_param3}; - - xmm_t xmm_src1 = xmm_t(0); - xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(2); - xmm_t xmm_zero = xmm_t(3); - - ymm_t ymm_src1 = ymm_t(0); - ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(2); - ymm_t ymm_zero = ymm_t(3); -}; - -class VActJitCode : public JitCode { - public: - const char* name() const override { - std::string base = "VActJitCode"; - switch (type_) { - case operand_type::relu: - base += "_Relu"; - break; - case operand_type::exp: - base += "_Exp"; - break; - case operand_type::sigmoid: - base += "_Sigmoid"; - break; - case operand_type::tanh: - base += "_Tanh"; - break; - case operand_type::identity: - base += "_Identity"; - break; - default: - break; - } - return base.c_str(); - } - - explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d), type_(type) {} - static bool init(int d, operand_type type); - void generate() override; - - protected: - // compute relu with ymm, xmm - template - void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT - JMM zero = JMM(zero_idx); - vxorps(zero, zero, zero); - vmaxps(dst, src, zero); - } - - // compute exp with ymm, xmm - template - void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT - int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { - using namespace platform; // NOLINT - // check all idx can not equal - JMM jmm_src = JMM(src_idx); - JMM jmm_fx = JMM(fx_idx); - JMM jmm_fy = JMM(fy_idx); - JMM jmm_mask = JMM(mask_idx); - JMM jmm_tmp = JMM(tmp_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - vmovaps(jmm_src, src); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); - vminps(jmm_src, jmm_src, jmm_tmp); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); - vmaxps(jmm_src, jmm_src, jmm_tmp); - // express exp(x) as exp(g + n*log(2)) - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); - vmulps(jmm_fx, jmm_src, jmm_tmp); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); - vaddps(jmm_fx, jmm_fx, jmm_tmp); - vroundps(jmm_fy, jmm_fx, 0x01); - // if greater, substract 1 - vcmpgtps(jmm_mask, jmm_fy, jmm_fx); - vmovaps(jmm_tmp, ptr[reg_ptr_global]); - vandps(jmm_mask, jmm_mask, jmm_tmp); - vsubps(jmm_fx, jmm_fy, jmm_mask); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); - vmulps(jmm_fy, jmm_fx, jmm_tmp); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); - JMM ymm_z = JMM(jmm_mask.getIdx()); - vmulps(ymm_z, jmm_fx, jmm_tmp); - vsubps(jmm_src, jmm_src, jmm_fy); - vsubps(jmm_src, jmm_src, ymm_z); - vmulps(ymm_z, jmm_src, jmm_src); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); - vmulps(dst, jmm_src, jmm_tmp); - for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; - i += (YMM_FLOAT_BLOCK * sizeof(float))) { - vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4 - vaddps(dst, dst, jmm_tmp); - vmulps(dst, dst, jmm_src); - } - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); - vaddps(dst, dst, jmm_tmp); - vmulps(dst, dst, ymm_z); - vaddps(dst, dst, jmm_src); - vmovaps(jmm_tmp, ptr[reg_ptr_global]); - vaddps(dst, dst, jmm_tmp); - // build 2^n - JMM ymm_int = jmm_fx; - vcvttps2dq(ymm_int, jmm_fx); - mov(reg_ptr_global, reinterpret_cast(exp_int_0x7f)); - vmovdqa(jmm_tmp, ptr[reg_ptr_global]); - if (MayIUse(avx2) || std::is_same::value) { - vpaddd(ymm_int, ymm_int, jmm_tmp); - vpslld(ymm_int, ymm_int, 23); - } else if (MayIUse(avx)) { - xmm_t xtmp1 = xmm_t(ymm_int.getIdx()); - xmm_t xtmp2 = xmm_t(jmm_tmp.getIdx()); - reg64_t reg_ptr_tmp = reg_ptr_global; - mov(reg_ptr_tmp, reinterpret_cast(g_tmp_mem)); - vmovdqa(ptr[reg_ptr_tmp], ymm_int); - vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], jmm_tmp); - vpaddd(xtmp1, xtmp1, xtmp2); - vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_tmp], xtmp1); - // next 128bits - vmovdqa(xtmp1, ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)]); - vmovdqa(xtmp2, ptr[reg_ptr_tmp + - (YMM_FLOAT_BLOCK + XMM_FLOAT_BLOCK) * sizeof(float)]); - vpaddd(xtmp1, xtmp1, xtmp2); - vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)], xtmp1); - // load out - vmovdqa(ymm_int, ptr[reg_ptr_tmp]); - } - vmulps(dst, dst, ymm_int); - pop(reg_ptr_global); - } - - // compute sigmoid with ymm, xmm - template - void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT - int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, - int tmp_idx = 15) { - // y = 1 / (1 + e^-x) - JMM jmm_tmp = JMM(tmp_idx); - JMM jmm_src = JMM(src_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - vmovaps(jmm_src, src); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); - vminps(jmm_src, jmm_src, jmm_tmp); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); - vmaxps(jmm_src, jmm_src, jmm_tmp); - vxorps(jmm_tmp, jmm_tmp, jmm_tmp); - vsubps(jmm_src, jmm_tmp, jmm_src); - exp_jmm(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vaddps(dst, dst, jmm_tmp); - vdivps(dst, jmm_tmp, dst); - pop(reg_ptr_global); - } - - // compute tanh with ymm, xmm - template - void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT - int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, - int tmp_idx = 15) { - // y = 2 / (1 + e^(-2x)) - 1 - JMM jmm_src = JMM(src_idx); - JMM jmm_tmp = JMM(tmp_idx); - JMM jmm_zero = JMM(mask_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - vmovaps(jmm_src, src); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); - vxorps(jmm_zero, jmm_zero, jmm_zero); - vsubps(jmm_tmp, jmm_zero, jmm_tmp); - vmulps(jmm_src, jmm_src, jmm_tmp); - exp_jmm(dst, jmm_src, src_idx, fx_idx, fy_idx, mask_idx, tmp_idx); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vaddps(dst, dst, jmm_tmp); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); - vdivps(dst, jmm_tmp, dst); - vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vsubps(dst, dst, jmm_tmp); - pop(reg_ptr_global); - } - - template - void act(JMM& dst, JMM& src, operand_type type) { // NOLINT - // use 11~15 - switch (type) { - case operand_type::relu: - relu_jmm(dst, src, 15); - break; - case operand_type::exp: - exp_jmm(dst, src, 11, 12, 13, 14, 15); - break; - case operand_type::sigmoid: - sigmoid_jmm(dst, src, 11, 12, 13, 14, 15); - break; - case operand_type::tanh: - tanh_jmm(dst, src, 11, 12, 13, 14, 15); - break; - case operand_type::identity: - break; - default: - // throw error - break; - } - } - - protected: - int num_; - operand_type type_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - - xmm_t xmm_src = xmm_t(0); - ymm_t ymm_src = ymm_t(0); - - xmm_t xmm_dst = xmm_t(1); - ymm_t ymm_dst = ymm_t(1); -}; - -class LSTMJitCode : public VActJitCode { - public: - const char* name() const override { - std::string base = "LSTMJitCode"; - if (use_peephole_) { - base += "_Peephole"; - } - if (compute_c1h1_) { - base += "_C1H1"; - } - auto AddTypeStr = [&](operand_type type) { - switch (type) { - case operand_type::relu: - base += "_Relu"; - break; - case operand_type::exp: - base += "_Exp"; - break; - case operand_type::sigmoid: - base += "_Sigmoid"; - break; - case operand_type::tanh: - base += "_Tanh"; - break; - case operand_type::identity: - base += "_Identity"; - break; - default: - break; - } - }; - AddTypeStr(act_gate_); - AddTypeStr(act_cand_); - AddTypeStr(act_cell_); - return base.c_str(); - } - - explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr, - size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, - code_ptr), - compute_c1h1_(compute_c1h1) { - auto typeExchange = [](const std::string& type) -> gen::operand_type { - if (type == "sigmoid") { - return operand_type::sigmoid; - } else if (type == "relu") { - return operand_type::relu; - } else if (type == "tanh") { - return operand_type::tanh; - } else if (type == "identity" || type == "") { - return operand_type::identity; - } // else throw error - return operand_type::identity; - }; - num_ = attr.d; - use_peephole_ = attr.use_peephole; - act_gate_ = typeExchange(attr.act_gate); - act_cand_ = typeExchange(attr.act_cand); - act_cell_ = typeExchange(attr.act_cell); - } - static bool init(int d); - void generate() override; - - protected: - int num_; - bool compute_c1h1_; - bool use_peephole_; - operand_type act_gate_; - operand_type act_cand_; - operand_type act_cell_; - reg64_t param1{abi_param1}; -}; - -class GRUJitCode : public VActJitCode { - public: - const char* name() const override { - std::string base = "GRUJitCode"; - if (id_ == 0) { - base += "_H1"; - } else if (id_ == 1) { - base += "_HtPart1"; - } else if (id_ == 2) { - base += "_HtPart2"; - } - auto AddTypeStr = [&](operand_type type) { - switch (type) { - case operand_type::relu: - base += "_Relu"; - break; - case operand_type::exp: - base += "_Exp"; - break; - case operand_type::sigmoid: - base += "_Sigmoid"; - break; - case operand_type::tanh: - base += "_Tanh"; - break; - case operand_type::identity: - base += "_Identity"; - break; - default: - break; - } - }; - AddTypeStr(act_gate_); - AddTypeStr(act_cand_); - return base.c_str(); - } - - explicit GRUJitCode(int id, const gru_attr_t& attr, - size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size, - code_ptr), - id_(id) { - auto typeExchange = [](const std::string& type) -> gen::operand_type { - if (type == "sigmoid") { - return operand_type::sigmoid; - } else if (type == "relu") { - return operand_type::relu; - } else if (type == "tanh") { - return operand_type::tanh; - } else if (type == "identity" || type == "") { - return operand_type::identity; - } // else throw error - return operand_type::identity; - }; - num_ = attr.d; - act_gate_ = typeExchange(attr.act_gate); - act_cand_ = typeExchange(attr.act_cand); - } - static bool init(int d); - void generate() override; - - protected: - int id_; - int num_; - operand_type act_gate_; - operand_type act_cand_; - reg64_t param1{abi_param1}; -}; - -#ifdef PADDLE_WITH_MKLDNN -struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { - explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) - : Xbyak::CodeGenerator(code_size) { - // RDI is ptr x_input - // RSI is ptr y_input - // RDX is ptr output - // RCX is height - // r8 is width - - push(rbx); - - xor_(rax, rax); - xor_(r10, r10); - vmovups(zmm3, ptr[rsi]); - - L("h_loop"); - xor_(rbx, rbx); - L("w_loop"); - vmovups(zmm2, ptr[rdi + rax]); - vmulps(zmm1, zmm2, zmm3); - vmovups(ptr[rdx + rax], zmm1); - add(rax, 64); - inc(rbx); - cmp(r8, rbx); - jnz("w_loop"); - inc(r10); - cmp(r10, rcx); - jnz("h_loop"); - - pop(rbx); - ret(); - } -}; -#endif - -} // namespace gen -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_gen.cc b/paddle/fluid/operators/math/jit_gen.cc deleted file mode 100644 index 5c6672928e8..00000000000 --- a/paddle/fluid/operators/math/jit_gen.cc +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_gen.h" -#include -#include -#include -#include "paddle/fluid/platform/cpu_info.h" - -DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace gen { - -constexpr Xbyak::Operand::Code g_abi_regs[] = { - Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, - Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15}; - -constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]); - -void JitCode::preCode() { - for (int i = 0; i < num_g_abi_regs; ++i) { - push(Xbyak::Reg64(g_abi_regs[i])); - } - if (platform::MayIUse(platform::avx512f)) { - mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); - } -} - -void JitCode::postCode() { - for (int i = 0; i < num_g_abi_regs; ++i) { - pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i])); - } - ret(); -} - -void JitCode::dumpCode(const Xbyak::uint8 *code) const { - if (code) { - static int counter = 0; - std::ostringstream filename; - filename << "paddle_jitcode_" << name() << "." << counter << ".bin"; - counter++; - std::ofstream fout(filename.str(), std::ios::out); - if (fout.is_open()) { - fout.write(reinterpret_cast(code), getSize()); - fout.close(); - } - } -} - -Xbyak::Address JitCode::EVEX_compress_addr(Xbyak::Reg64 base, int offt, - bool bcast) { - int scale = 0; - if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { - offt = offt - 2 * EVEX_max_8b_offt; - scale = 1; - } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { - offt = offt - 4 * EVEX_max_8b_offt; - scale = 2; - } - auto re = Xbyak::RegExp() + base + offt; - if (scale) { - re = re + reg_EVEX_max_8b_offt * scale; - } - if (bcast) { - return zword_b[re]; - } else { - return zword[re]; - } -} - -} // namespace gen -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_gen.h b/paddle/fluid/operators/math/jit_gen.h deleted file mode 100644 index 6abf3434cc8..00000000000 --- a/paddle/fluid/operators/math/jit_gen.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include "paddle/fluid/platform/macros.h" - -#define XBYAK_USE_MMAP_ALLOCATOR -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" - -DECLARE_bool(dump_jitcode); - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace gen { - -#define DECLARE_JIT_CODE(codename) \ - const char *name() const override { return #codename; } - -// Application Binary Interface -constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI), - abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX), - abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX); - -class JitCode : public Xbyak::CodeGenerator { - public: - explicit JitCode(size_t code_size = 256 * 1024, void *code_ptr = nullptr) - : Xbyak::CodeGenerator(code_size, code_ptr) {} - - virtual ~JitCode() {} - virtual const char *name() const = 0; - virtual void generate() = 0; - - template - const FUNC getCode() { - this->generate(); - const Xbyak::uint8 *code = CodeGenerator::getCode(); - if (FLAGS_dump_jitcode) { - this->dumpCode(code); - } - return reinterpret_cast(code); - } - DISABLE_COPY_AND_ASSIGN(JitCode); - - protected: - Xbyak::Reg64 param1{abi_param1}; - const int EVEX_max_8b_offt = 0x200; - const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; - - void preCode(); - void postCode(); - void dumpCode(const Xbyak::uint8 *code) const; - void L(const char *label) { Xbyak::CodeGenerator::L(label); } - void L(const Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); } - // Enhanced vector extension - Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt, - bool bcast = false); -}; - -} // namespace gen -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc deleted file mode 100644 index 118696ba479..00000000000 --- a/paddle/fluid/operators/math/jit_kernel.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -KernelPool& KernelPool::Instance() { - static thread_local KernelPool g_jit_kernels; - return g_jit_kernels; -} - -std::shared_ptr KernelPool::Get(const std::string& key) const { - if (kers_.find(key) == kers_.end()) { - return nullptr; - } - return kers_.at(key); -} - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h deleted file mode 100644 index b78b92b4f97..00000000000 --- a/paddle/fluid/operators/math/jit_kernel.h +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include // for shared_ptr -#include -#include -#include "paddle/fluid/operators/math/jit_kernel_impl.h" -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/macros.h" - -// Note: Only support on CPU yet. -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -// TODO(TJ): remove me -typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; - -class Kernel { - public: - Kernel() = default; - virtual ~Kernel() = default; - // TODO(TJ): below members should be deprecated. - int num_{0}; - int end_{0}; - int rest_{0}; - DISABLE_COPY_AND_ASSIGN(Kernel); -}; - -class KernelPool { - public: - static KernelPool &Instance(); - - template - std::shared_ptr Get(ARGS... args); - - std::shared_ptr Get(const std::string &key) const; - - private: - KernelPool() = default; - std::unordered_map> kers_; - - DISABLE_COPY_AND_ASSIGN(KernelPool); -}; - -template -class VMulKernel : public Kernel { - public: - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VAddKernel : public Kernel { - public: - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VAddReluKernel : public Kernel { - public: - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VScalKernel : public Kernel { - public: - // y = a.*x - void (*Compute)(const T *, const T *, T *, int); -}; - -template -class VAddBiasKernel : public Kernel { - public: - // y = a.+x - void (*Compute)(const T *, const T *, T *, int); -}; - -#ifdef PADDLE_WITH_MKLDNN -template -class EltwiseMulnChw16cNCKernel : public Kernel { - public: - // nChw16c = nChw16c .* NC - void (*Compute)(const float *, const float *, float *, int, int); -}; -#endif - -template -class VActKernel : public Kernel { - public: - void (*Compute)(const T *, T *, int); -}; - -template -class VReluKernel : public VActKernel {}; - -template -class VIdentityKernel : public VActKernel {}; - -template -class VExpKernel : public VActKernel {}; - -template -class VSigmoidKernel : public VActKernel {}; - -template -class VTanhKernel : public VActKernel {}; - -template -class LSTMKernel : public Kernel { - public: - // compute c1 and h1 without c0 or h0 - void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *); - void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *); -}; - -template -class GRUKernel : public Kernel { - public: - // compute h1 without h0 - void (*ComputeH1)(gru_t *, const gru_attr_t *); - void (*ComputeHtPart1)(gru_t *, const gru_attr_t *); - void (*ComputeHtPart2)(gru_t *, const gru_attr_t *); -}; - -template -class CRFDecodeKernel : public Kernel { - public: - virtual void Compute(const int seq_len, const T *x, const T *w, T *alpha, - int *track) const = 0; -}; - -template -class LayerNormKernel : public Kernel { - public: - virtual void Compute(T *x, T *out, T *mean, T *var, const T *scale, - const T *bias, int height, - const float epsilon) const = 0; -}; - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc deleted file mode 100644 index 682e51e89d6..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ /dev/null @@ -1,346 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" -#include "paddle/fluid/platform/enforce.h" - -#ifdef PADDLE_WITH_XBYAK -#include "paddle/fluid/operators/math/jit_code.h" -#endif - -#ifdef PADDLE_WITH_MKLML -#include "paddle/fluid/platform/dynload/mklml.h" -#endif - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* VMUL JitKernel */ -template -class VMulKernelImpl : public VMulKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VMulKernelImpl(int d) : VMulKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - // roughly estimate the size of code - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, - sz > 4096 ? sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VMulMKL; - return; - } -#endif - this->Compute = refer::VMul; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VMulKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VMulKernelImpl::useMKL(int d) { - return platform::MayIUse(platform::avx512f) && d > 512; -} - -template <> -bool VMulKernelImpl::useMKL(int d) { - return true; -} -#endif - -/* VAdd JitKernel */ -template -class VAddKernelImpl : public VAddKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VAddKernelImpl(int d) : VAddKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, - sz > 4096 ? sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VAddMKL; - return; - } -#endif - this->Compute = refer::VAdd; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VAddKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VAddKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VAddKernelImpl::useMKL(int d) { - return true; -} -#endif - -#ifdef PADDLE_WITH_MKLDNN -/* EltwiseMul for nChw16c & NC inputs JitKernel */ -template -class EltwiseMulnChw16cNCKernelImpl - : public math::jitkernel::EltwiseMulnChw16cNCKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit EltwiseMulnChw16cNCKernelImpl(int d) - : EltwiseMulnChw16cNCKernel() { - using mul_func_t = void (*)(const float*, const float*, float*, int, int); -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - // roughly estimate the size of code - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - sz = sz > 4096 ? sz : 4096; - jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz)); - this->Compute = (mul_func_t)jitcode_->getCode(); - return; - } -#endif - PADDLE_THROW( - "This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN " - "environemnt"); - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -}; - -template <> -bool EltwiseMulnChw16cNCKernelImpl::useJIT(int d) { - return true; -} -#endif -#endif - -/* VAddRelu JitKernel */ -template -class VAddReluKernelImpl : public VAddReluKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VAddReluKernelImpl(int d) : VAddReluKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, - sz > 4096 ? sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif - this->Compute = refer::VAddRelu; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VAddReluKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d); -} -#endif - -/* VScal JitKernel */ -template -class VScalKernelImpl : public VScalKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VScalKernelImpl(int d) : VScalKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, - sz > 4096 ? sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VScalMKL; - return; - } -#endif - this->Compute = refer::VScal; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VScalKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d, 1); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VScalKernelImpl::useMKL(int d) { - return d > 512; -} -template <> -bool VScalKernelImpl::useMKL(int d) { - return true; -} -#endif - -/* VAddBias JitKernel */ -template -class VAddBiasKernelImpl : public VAddBiasKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false, - sz > 4096 ? sz : 4096)); - this->Compute = - jitcode_->getCode(); - return; - } -#endif - - this->Compute = refer::VAddBias; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VAddBiasKernelImpl::useJIT(int d) { - return gen::VXXJitCode::init(d, 1); -} -#endif - -/* VRelu JitKernel */ -template -class VReluKernelImpl : public VReluKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VReluKernelImpl(int d) : VReluKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 /* init size */ + - d / YMM_FLOAT_BLOCK * 4 /* instructions */ * - 8 /* average bytes for each instruction */; - jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu, - sz > 4096 ? sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif - - this->Compute = refer::VRelu; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VReluKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::relu); -} -#endif - -/* An empty JitKernel */ -template -class VIdentityKernelImpl : public VIdentityKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VIdentityKernelImpl(int d) : VIdentityKernel() { - this->Compute = refer::VIdentity; - } -}; - -REGISTER_JITKERNEL(vmul, VMulKernel); -REGISTER_JITKERNEL(vadd, VAddKernel); -REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); -REGISTER_JITKERNEL(vscal, VScalKernel); -REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); -REGISTER_JITKERNEL(vrelu, VReluKernel); -REGISTER_JITKERNEL(videntity, VIdentityKernel); -#ifdef PADDLE_WITH_MKLDNN -REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel); -#endif - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc deleted file mode 100644 index ac2d29f1c18..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc +++ /dev/null @@ -1,291 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* CRF Decode JitKernel */ -template -class CRFDecodeKernelImpl : public CRFDecodeKernel { - public: - explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel() { - this->num_ = tag_num; - } - void Compute(const int seq_len, const T* x, const T* w, T* alpha, - int* track) const override { - constexpr int state_trans_base_idx = 2; - for (int i = 0; i < this->num_; ++i) { - alpha[i] = w[i] + x[i]; - } - for (int k = 1; k < seq_len; ++k) { - for (int i = 0; i < this->num_; ++i) { - T max_score = -std::numeric_limits::max(); - int max_j = 0; - for (int j = 0; j < this->num_; ++j) { - T score = alpha[(k - 1) * this->num_ + j] + - w[(j + state_trans_base_idx) * this->num_ + i]; - if (score > max_score) { - max_score = score; - max_j = j; - } - } - alpha[k * this->num_ + i] = max_score + x[k * this->num_ + i]; - track[k * this->num_ + i] = max_j; - } - } - } -}; - -#define INIT_ALPHA(step_size) \ - /* Setup the alpha initial value.*/ \ - int i_offset = 0; \ - int last_offset = this->rest_ - step_size; \ - for (int i = 0; i <= this->end_; ++i) { \ - /* weights, input and alpha values. */ \ - __m256 w_content, x_content, alpha_content; \ - /* Load the relevant data into the variables from un-aligned address.*/ \ - w_content = _mm256_loadu_ps(w + i_offset); \ - x_content = _mm256_loadu_ps(x + i_offset); \ - alpha_content = _mm256_add_ps(w_content, x_content); \ - _mm256_storeu_ps(alpha + i_offset, alpha_content); \ - i_offset += step_size; \ - if (i == this->end_ - 1) { \ - if (this->rest_ > 0) { \ - i_offset += last_offset; \ - } else { \ - break; \ - } \ - } \ - } - -#define UPDATE_ALPHA(step_size) \ - /* Update the alpha and track values. */ \ - __m256 x_content = _mm256_loadu_ps(x + seq_offset + this->num_ + j_offset); \ - max_score = _mm256_add_ps(max_score, x_content); \ - _mm256_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score); \ - _mm256_storeu_si256( \ - reinterpret_cast<__m256i*>(track + seq_offset + this->num_ + j_offset), \ - max_j); \ - /* Calculate the offset of next step*/ \ - j_offset += step_size; \ - if (j == this->end_ - 1) { \ - if (this->rest_ > 0) { \ - j_offset += last_offset; \ - } else { \ - break; \ - } \ - } - -#define INTRIAVX_FLOAT(block) \ - template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ - int tag_num) \ - : CRFDecodeKernel() { \ - this->num_ = tag_num; \ - this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ - this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ - } \ - template <> \ - void CRFDecodeKernelImpl::Compute( \ - const int seq_len, const float* x, const float* w, float* alpha, \ - int* track) const { \ - INIT_ALPHA(YMM_FLOAT_BLOCK) \ - /* Use the column-major strategy to get the location of maximum score.*/ \ - int seq_offset = 0; \ - constexpr int state_trans_base_idx = 2; \ - for (int k = 1; k < seq_len; ++k) { \ - int j_offset = 0; \ - for (int j = 0; j <= this->end_; ++j) { \ - /* Initialize the variables of maximum score and location.*/ \ - __m256 max_score = _mm256_set1_ps(-std::numeric_limits::max()); \ - __m256i max_j = _mm256_set1_epi32(0); \ - /* Calculate the offset of transition_weights.*/ \ - int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ - for (int i = 0; i < this->num_; ++i) { \ - /* Initalize the content of alpha variable with related offset.*/ \ - __m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \ - /* Obtain the content of weights from un-aligned address.*/ \ - __m256 w_content = _mm256_loadu_ps(w + trans_offset); \ - __m256 score_v = _mm256_add_ps(alpha_content, w_content); \ - __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \ - /* According to the mask value, update the index of the max_score.*/ \ - /* AVX instructions.*/ \ - __m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); \ - __m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); \ - __m128i lo_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 0); \ - __m128i hi_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 1); \ - lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); \ - hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); \ - lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); \ - hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); \ - lo_max_j = _mm_or_si128(lo_mask, lo_max_j); \ - hi_max_j = _mm_or_si128(hi_mask, hi_max_j); \ - max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); \ - max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); \ - /* AVX done*/ \ - /* Update the max_score value.*/ \ - max_score = _mm256_max_ps(max_score, score_v); \ - trans_offset += this->num_; \ - } \ - UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ - } \ - seq_offset += this->num_; \ - } \ - } - -#define INTRIAVX2_FLOAT(isa, block) \ - template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl(int tag_num) \ - : CRFDecodeKernel() { \ - this->num_ = tag_num; \ - this->end_ = this->num_ / YMM_FLOAT_BLOCK; \ - this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ - } \ - template <> \ - void CRFDecodeKernelImpl::Compute( \ - const int seq_len, const float* x, const float* w, float* alpha, \ - int* track) const { \ - INIT_ALPHA(YMM_FLOAT_BLOCK) \ - /* Use the column-major strategy to get the location of maximum score.*/ \ - int seq_offset = 0; \ - constexpr int state_trans_base_idx = 2; \ - for (int k = 1; k < seq_len; ++k) { \ - int j_offset = 0; \ - for (int j = 0; j <= this->end_; ++j) { \ - /* Initialize the variables of maximum score and location.*/ \ - __m256 max_score = _mm256_set1_ps(-std::numeric_limits::max()); \ - __m256i max_j = _mm256_set1_epi32(0); \ - /* Calculate the offset of transition_weights.*/ \ - int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ - for (int i = 0; i < this->num_; ++i) { \ - /* Initalize the content of alpha variable with related offset.*/ \ - __m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \ - /* Obtain the content of weights from un-aligned address.*/ \ - __m256 w_content = _mm256_loadu_ps(w + trans_offset); \ - __m256 score_v = _mm256_add_ps(alpha_content, w_content); \ - __m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \ - /* According to the mask value, update the index of the max_score.*/ \ - /* AVX2 instructions.*/ \ - max_j = _mm256_or_si256( \ - _mm256_andnot_si256((__m256i)mask, max_j), \ - _mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); \ - /* Update the max_score value.*/ \ - max_score = _mm256_max_ps(max_score, score_v); \ - trans_offset += this->num_; \ - } \ - UPDATE_ALPHA(YMM_FLOAT_BLOCK) \ - } \ - seq_offset += this->num_; \ - } \ - } - -#define INTRIAVX512_FLOAT(block) \ - template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ - int tag_num) \ - : CRFDecodeKernel() { \ - this->num_ = tag_num; \ - this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \ - this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \ - } \ - template <> \ - void CRFDecodeKernelImpl::Compute( \ - const int seq_len, const float* x, const float* w, float* alpha, \ - int* track) const { \ - INIT_ALPHA(ZMM_FLOAT_BLOCK) \ - /* Use the column-major strategy to get the location of maximum score.*/ \ - int seq_offset = 0; \ - constexpr int state_trans_base_idx = 2; \ - for (int k = 1; k < seq_len; ++k) { \ - int j_offset = 0; \ - for (int j = 0; j <= this->end_; ++j) { \ - /* Initialize the variables of maximum score and location.*/ \ - __m512 max_score = _mm512_set1_ps(-std::numeric_limits::max()); \ - __m512i max_j = _mm512_setzero_si512(); \ - /* Calculate the offset of transition_weights.*/ \ - int trans_offset = state_trans_base_idx * this->num_ + j_offset; \ - for (int i = 0; i < this->num_; ++i) { \ - /* Initalize the content of alpha variable with related offset.*/ \ - __m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i)); \ - /* Obtain the content of weights from un-aligned address.*/ \ - __m512 w_content = _mm512_loadu_ps(w + trans_offset); \ - __m512 score_v = _mm512_add_ps(alpha_content, w_content); \ - __mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); \ - /* AVX512 instructions.*/ \ - max_j = _mm512_mask_set1_epi32(max_j, mask, i); \ - /* Update the max_score value.*/ \ - max_score = _mm512_max_ps(max_score, score_v); \ - trans_offset += this->num_; \ - } \ - /* Update the alpha and track values.*/ \ - __m512 x_content = \ - _mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \ - max_score = _mm512_add_ps(max_score, x_content); \ - _mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \ - max_score); \ - _mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \ - this->num_ + j_offset), \ - max_j); \ - /* Calculate the offset of next step*/ \ - j_offset += ZMM_FLOAT_BLOCK; \ - if (j == this->end_ - 1) { \ - if (this->rest_ > 0) { \ - j_offset += last_offset; \ - } else { \ - break; \ - } \ - } \ - } \ - seq_offset += this->num_; \ - } \ - } - -#ifdef __AVX__ -INTRIAVX_FLOAT(kEQ8); -INTRIAVX_FLOAT(kGT8LT16); -INTRIAVX_FLOAT(kEQ16); -INTRIAVX_FLOAT(kGT16); -#endif -#ifdef __AVX2__ -INTRIAVX2_FLOAT(platform::avx2, kEQ8); -INTRIAVX2_FLOAT(platform::avx2, kGT8LT16); -INTRIAVX2_FLOAT(platform::avx2, kEQ16); -INTRIAVX2_FLOAT(platform::avx2, kGT16); -#endif -#ifdef __AVX512F__ -INTRIAVX2_FLOAT(platform::avx512f, kEQ8); -INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16); -INTRIAVX512_FLOAT(kEQ16); -INTRIAVX512_FLOAT(kGT16); -#endif - -#undef INTRIAVX512_FLOAT -#undef INTRIAVX2_FLOAT -#undef INTRIAVX_FLOAT -#undef INIT_ALPHA -#undef UPDATE_ALPHA - -REGISTER_JITKERNEL_DEPRECATED(crf_decode, CRFDecodeKernel); - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc deleted file mode 100644 index 1f97ed1e62c..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ /dev/null @@ -1,195 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" - -#ifdef PADDLE_WITH_XBYAK -#include "paddle/fluid/operators/math/jit_code.h" -#endif - -#ifdef PADDLE_WITH_MKLML -#include "paddle/fluid/platform/dynload/mklml.h" -#endif - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* VExp JitKernel */ -template -class VExpKernelImpl : public VExpKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VExpKernelImpl(int d) : VExpKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8; - jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp, - sz > 4096 ? sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif -#ifdef PADDLE_WITH_MKLML - if (useMKL(d)) { - this->Compute = VExpMKL; - return; - } -#endif - this->Compute = refer::VExp; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VExpKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::exp); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VExpKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VExpKernelImpl::useMKL(int d) { - return true; -} - -#endif - -/* VSigmoid JitKernel */ -template -class VSigmoidKernelImpl : public VSigmoidKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8; - jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid, - sz > 4096 ? sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif - -#ifdef PADDLE_WITH_MKLML - // strictly it's a better impl with MKL, then is refer - if (useMKL(d)) { - this->Compute = VSigmoidMKL; - return; - } -#endif - this->Compute = refer::VSigmoid; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VSigmoidKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::sigmoid); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VSigmoidKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VSigmoidKernelImpl::useMKL(int d) { - return true; -} -#endif - -/* VTanh JitKernel */ -template -class VTanhKernelImpl : public VTanhKernel { - public: - JITKERNEL_DECLARE_STATIC_FUNC; - explicit VTanhKernelImpl(int d) : VTanhKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(d)) { - size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8; - jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh, - sz > 4096 ? sz : 4096)); - this->Compute = jitcode_->getCode(); - return; - } -#endif - -#ifdef PADDLE_WITH_MKLML - // strictly it's a better impl with MKL, then is refer - if (useMKL(d)) { - this->Compute = VTanhMKL; - return; - } -#endif - this->Compute = refer::VTanh; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool VTanhKernelImpl::useJIT(int d) { - return gen::VActJitCode::init(d, gen::operand_type::tanh); -} -#endif - -#ifdef PADDLE_WITH_MKLML -template <> -bool VTanhKernelImpl::useMKL(int d) { - return d > 512; -} - -template <> -bool VTanhKernelImpl::useMKL(int d) { - return true; -} -#endif - -REGISTER_JITKERNEL(vexp, VExpKernel); -REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); -REGISTER_JITKERNEL(vtanh, VTanhKernel); - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h deleted file mode 100644 index 025343dfad4..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_impl.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 -#define XMM_FLOAT_BLOCK 4 -#define YMM_FLOAT_BLOCK 8 -#define ZMM_FLOAT_BLOCK 16 - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc deleted file mode 100644 index e21092037a2..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* Layer Norm JitKernel */ -template -class LayerNormKernelImpl : public LayerNormKernel { - public: - explicit LayerNormKernelImpl(int right) : LayerNormKernel() { - this->num_ = right; - } - - void Compute(T* x, T* out, T* mean, T* var, const T* scale, const T* bias, - int height, const float epsilon) const override { - // get mean - for (int i = 0; i < height; i++) { - T sum = 0.0; - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - sum += x[offset + j]; - } - mean[i] = sum / this->num_; - } - - // get variance - for (int i = 0; i < height; i++) { - T sum = 0.0; - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]); - } - var[i] = sum / this->num_; - } - - for (int i = 0; i < height; i++) { - int offset = i * this->num_; - T sqrt_var = sqrt(var[i] + (T)epsilon); - for (int j = 0; j < this->num_; j++) { - out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var; - } - } - if (scale) { - for (int i = 0; i < height; i++) { - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - out[offset + j] *= scale[j]; - } - } - } - - if (bias) { - for (int i = 0; i < height; i++) { - int offset = i * this->num_; - for (int j = 0; j < this->num_; j++) { - out[offset + j] += bias[j]; - } - } - } - } -}; - -#define INTRIAVX_FLOAT(isa, jit_block) \ - template <> \ - LayerNormKernelImpl::LayerNormKernelImpl(int right) \ - : LayerNormKernel() { \ - this->num_ = right; \ - this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ - this->end_ = this->num_ - this->rest_; \ - } \ - template <> \ - void LayerNormKernelImpl::Compute( \ - float* x, float* out, float* mean, float* var, const float* scale, \ - const float* bias, int height, const float epsilon) const { \ - __m256 sum; \ - __m256 mean_vec, var_vec; \ - __m128 hi, lo; \ - __m256 tmp; \ - size_t offset; \ - size_t j; \ - size_t block = YMM_FLOAT_BLOCK; \ - __m256 reverse_num_vec = \ - _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \ - __m256 epsilon_vec = _mm256_set1_ps(epsilon); \ - int rest_mask = \ - ((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \ - 0x0ff; \ - __m256i mask_vec = _mm256_set_epi32( \ - rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \ - rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \ - rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, \ - rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); \ - \ - for (int i = 0; i < height; ++i) { \ - offset = i * this->num_; \ - \ - /* get mean */ \ - sum = _mm256_setzero_ps(); \ - for (j = offset; j < end_ + offset; j += block) { \ - sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_loadu_ps((const float*)x + j); \ - tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \ - sum = _mm256_add_ps(sum, tmp); \ - } \ - hi = _mm256_extractf128_ps(sum, 1); \ - lo = _mm256_extractf128_ps(sum, 0); \ - sum = _mm256_add_ps( \ - sum, _mm256_insertf128_ps( \ - _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \ - sum = _mm256_hadd_ps(sum, sum); \ - sum = _mm256_hadd_ps(sum, sum); \ - mean_vec = _mm256_mul_ps(sum, reverse_num_vec); \ - mean[i] = *reinterpret_cast(&mean_vec); \ - \ - /* get variance */ \ - sum = _mm256_setzero_ps(); \ - for (j = offset; j < end_ + offset; j += block) { \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_mul_ps(tmp, tmp); \ - sum = _mm256_add_ps(sum, tmp); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_mul_ps(tmp, tmp); \ - tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \ - sum = _mm256_add_ps(sum, tmp); \ - } \ - hi = _mm256_extractf128_ps(sum, 1); \ - lo = _mm256_extractf128_ps(sum, 0); \ - sum = _mm256_add_ps( \ - sum, _mm256_insertf128_ps( \ - _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \ - sum = _mm256_hadd_ps(sum, sum); \ - sum = _mm256_hadd_ps(sum, sum); \ - var_vec = _mm256_mul_ps(sum, reverse_num_vec); \ - var[i] = *reinterpret_cast(&var_vec); \ - \ - /* get x_norm and calculate output*/ \ - for (j = offset; j < end_ + offset; j += block) { \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_div_ps( \ - tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \ - _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); \ - } \ - if (rest_ != 0) { \ - j = offset + num_ - block; \ - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \ - tmp = _mm256_div_ps( \ - tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \ - _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); \ - } \ - \ - if (scale) { \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_loadu_ps((const float*)out + j); \ - } \ - for (j = offset; j < end_ + offset; j += block) { \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_mul_ps( \ - _mm256_loadu_ps((const float*)out + j), \ - _mm256_loadu_ps((const float*)scale + j - offset))); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_mul_ps( \ - tmp, _mm256_loadu_ps((const float*)scale + j - offset))); \ - } \ - } \ - \ - if (bias) { \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - tmp = _mm256_loadu_ps((const float*)out + j); \ - } \ - for (j = offset; j < end_ + offset; j += block) { \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_add_ps( \ - _mm256_loadu_ps((const float*)out + j), \ - _mm256_loadu_ps((const float*)bias + j - offset))); \ - } \ - if (rest_ != 0) { \ - j = offset + this->num_ - block; \ - _mm256_storeu_ps( \ - reinterpret_cast(out) + j, \ - _mm256_add_ps( \ - tmp, _mm256_loadu_ps((const float*)bias + j - offset))); \ - } \ - } \ - } \ - } - -#ifdef __AVX__ -INTRIAVX_FLOAT(platform::avx, kEQ8); -INTRIAVX_FLOAT(platform::avx, kGT8LT16); -INTRIAVX_FLOAT(platform::avx, kEQ16); -INTRIAVX_FLOAT(platform::avx, kGT16); -INTRIAVX_FLOAT(platform::avx2, kEQ8); -INTRIAVX_FLOAT(platform::avx2, kGT8LT16); -INTRIAVX_FLOAT(platform::avx2, kEQ16); -INTRIAVX_FLOAT(platform::avx2, kGT16); -INTRIAVX_FLOAT(platform::avx512f, kEQ8); -INTRIAVX_FLOAT(platform::avx512f, kGT8LT16); -INTRIAVX_FLOAT(platform::avx512f, kEQ16); -INTRIAVX_FLOAT(platform::avx512f, kGT16); -#endif - -#undef INTRIAVX_FLOAT - -REGISTER_JITKERNEL_DEPRECATED(layer_norm, LayerNormKernel); - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h deleted file mode 100644 index 4dba3b56810..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -#define JITKERNEL_DECLARE_STATIC_FUNC \ - static inline std::string name(int d) { \ - PADDLE_THROW("DType should be either float or double"); \ - } \ - static inline bool useJIT(int d) { return false; } \ - static inline bool useMKL(int d) { return false; } - -#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \ - template <> \ - std::string ker_class##Impl::name(int d) { \ - std::string key(#ker_key "f"); \ - if (useJIT(d)) { \ - /* only jit code need record d*/ \ - return key + "jit" + std::to_string(d); \ - } else if (useMKL(d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } \ - template <> \ - std::string ker_class##Impl::name(int d) { \ - std::string key(#ker_key "d"); \ - /* jit code do not support double yet*/ \ - if (useMKL(d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } - -#define JITKERNEL_DECLARE(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, int>(int d) - -#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \ - std::string key = ker_class##Impl::name(d) - -#define JITKERNEL_IMPL(ker_class, ker_dtype) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(d)) - -#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \ - macro_find_key, macro_impl) \ - marco_declare(ker_class, ker_dtype) { \ - macro_find_key(ker_class, ker_dtype); \ - if (kers_.find(key) == kers_.end()) { \ - std::shared_ptr> p; \ - macro_impl(ker_class, ker_dtype); \ - kers_.insert({key, std::dynamic_pointer_cast(p)}); \ - return p; \ - } \ - return std::dynamic_pointer_cast>( \ - kers_.at(key)); \ - } - -#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \ - marco_declare, macro_find_key, macro_impl) \ - marco_define_name(ker_key, ker_class); \ - REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \ - macro_find_key, macro_impl); \ - REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \ - macro_find_key, macro_impl) - -#define REGISTER_JITKERNEL(ker_key, ker_class) \ - REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \ - JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \ - JITKERNEL_IMPL) - -// TODO(TJ): below defines are deprecated, would be remove recently -#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ - if (d < YMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kLT8); \ - } else if (d == YMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ8); \ - } else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kGT8LT16); \ - } else if (d == ZMM_FLOAT_BLOCK) { \ - macro_(ker, dtype, isa, kEQ16); \ - } else { \ - macro_(ker, dtype, isa, kGT16); \ - } - -#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ - if (platform::MayIUse(platform::avx512f)) { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \ - } else if (platform::MayIUse(platform::avx2)) { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \ - } else if (platform::MayIUse(platform::avx)) { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \ - } else { \ - SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \ - } - -#define JITKERNEL_KEY(ker_key, dtype_key) \ - #ker_key #dtype_key + std::to_string(d) - -#define JITKERNEL_NEW_IMPL_DEPRECATED(ker, dtype, isa, k) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(d)) - -#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \ - dtype_key, marco_declare, macro_key, \ - macro_impl) \ - marco_declare(ker_class, ker_dtype) { \ - std::string key = macro_key(ker_key, dtype_key); \ - if (kers_.find(key) == kers_.end()) { \ - std::shared_ptr> p; \ - SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \ - kers_.insert({key, std::dynamic_pointer_cast(p)}); \ - return p; \ - } \ - return std::dynamic_pointer_cast>( \ - kers_.at(key)); \ - } - -#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \ - JITKERNEL_DECLARE, JITKERNEL_KEY, \ - JITKERNEL_NEW_IMPL_DEPRECATED); \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \ - JITKERNEL_DECLARE, JITKERNEL_KEY, \ - JITKERNEL_NEW_IMPL_DEPRECATED) - -#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \ - macro_key, macro_impl) \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \ - macro_key, macro_impl); \ - JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \ - marco_declare, macro_key, macro_impl) - -#define FOR_EACH_ISA(macro_, block) \ - macro_(platform::avx512f, block); \ - macro_(platform::avx2, block); \ - macro_(platform::avx, block); \ - macro_(platform::isa_any, block) - -#define FOR_EACH_BLOCK(macro_, isa) \ - macro_(isa, kLT8); \ - macro_(isa, kEQ8); \ - macro_(isa, kGT8LT16); \ - macro_(isa, kEQ16); \ - macro_(isa, kGT16) - -#define FOR_EACH_ISA_BLOCK(macro_) \ - FOR_EACH_BLOCK(macro_, platform::avx512f); \ - FOR_EACH_BLOCK(macro_, platform::avx2); \ - FOR_EACH_BLOCK(macro_, platform::avx); \ - FOR_EACH_BLOCK(macro_, platform::isa_any) - -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_refer.h b/paddle/fluid/operators/math/jit_kernel_refer.h deleted file mode 100644 index d49fc935dc5..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_refer.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include "paddle/fluid/operators/math/jit_kernel_impl.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { -namespace refer {} // namespace refer -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc deleted file mode 100644 index 2db3274a456..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include -#include "paddle/fluid/operators/math/jit_kernel_macro.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/macros.h" - -#ifdef PADDLE_WITH_XBYAK -#include "paddle/fluid/operators/math/jit_code.h" -#endif - -namespace paddle { -namespace operators { -namespace math { -namespace jitkernel { - -/* LSTM JitKernel */ -template -class LSTMKernelImpl : public LSTMKernel { - public: - static inline std::string name(const lstm_attr_t& attr) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(attr.d)) { - size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; - jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); - this->ComputeCtHt = - jitcode0_->getCode(); - - jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096)); - this->ComputeC1H1 = - jitcode1_->getCode(); - return; - } -#endif - - this->ComputeCtHt = refer::LSTMCtHt; - this->ComputeC1H1 = refer::LSTMC1H1; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool LSTMKernelImpl::useJIT(int d) { - return gen::LSTMJitCode::init(d); -} -#endif - -/* Peephole JitKernel */ -template -class PeepholeKernelImpl : public LSTMKernel { - public: - static inline std::string name(const lstm_attr_t& attr) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(attr.d)) { - size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 4 * 8; - jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096)); - this->ComputeCtHt = - jitcode0_->getCode(); - - jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096)); - this->ComputeC1H1 = - jitcode1_->getCode(); - return; - } -#endif - - this->ComputeCtHt = refer::LSTMCtHt; - this->ComputeC1H1 = refer::LSTMC1H1; - } - -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool PeepholeKernelImpl::useJIT(int d) { - return gen::LSTMJitCode::init(d); -} -#endif - -#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \ - template <> \ - std::string ker_class##Impl::name(const lstm_attr_t& attr) { \ - std::string key(#ker_key "f"); \ - key += (attr.act_gate + attr.act_cand + attr.act_cell + \ - (attr.use_peephole ? "p" : "n")); \ - if (useJIT(attr.d)) { \ - /* only jit code need record d*/ \ - return key + "jit" + std::to_string(attr.d); \ - } else if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } \ - template <> \ - std::string ker_class##Impl::name(const lstm_attr_t& attr) { \ - std::string key(#ker_key "d"); \ - /* jit code do not support double yet*/ \ - if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } - -#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, const lstm_attr_t&>( \ - const lstm_attr_t& attr) - -#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \ - std::string key = ker_class##Impl::name(attr) - -#define JITKERNEL_LSTM_IMPL(ker, dtype) \ - if (attr.use_peephole) { \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(attr)); \ - } else { \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(attr)); \ - } - -REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM, - JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM, - JITKERNEL_LSTM_IMPL); - -#undef JITKERNEL_LSTM_IMPL -#undef JITKERNEL_FIND_KEY_LSTM -#undef JITKERNEL_DECLARE_LSTM -#undef JITKERNEL_DEFINE_NAME_LSTM - -/* GRU JitKernel */ -template -class GRUKernelImpl : public GRUKernel { - public: - static inline std::string name(const gru_attr_t& attr) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel() { -#ifdef PADDLE_WITH_XBYAK - if (useJIT(attr.d)) { - size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; - jitcode0_.reset(new gen::GRUJitCode(0, attr, sz > 4096 ? sz : 4096)); - this->ComputeH1 = - jitcode0_->getCode(); - - jitcode1_.reset(new gen::GRUJitCode(1, attr, sz > 4096 ? sz : 4096)); - this->ComputeHtPart1 = - jitcode1_->getCode(); - - jitcode2_.reset(new gen::GRUJitCode(2, attr, sz > 4096 ? sz : 4096)); - this->ComputeHtPart2 = - jitcode2_->getCode(); - return; - } -#endif - this->ComputeH1 = refer::GRUH1; - this->ComputeHtPart1 = refer::GRUHtPart1; - this->ComputeHtPart2 = refer::GRUHtPart2; - } -#ifdef PADDLE_WITH_XBYAK - - private: - std::unique_ptr jitcode0_{nullptr}, jitcode1_{nullptr}, - jitcode2_{nullptr}; -#endif -}; - -#ifdef PADDLE_WITH_XBYAK -template <> -bool GRUKernelImpl::useJIT(int d) { - return gen::GRUJitCode::init(d); -} -#endif - -#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \ - template <> \ - std::string ker_class##Impl::name(const gru_attr_t& attr) { \ - std::string key(#ker_key "f"); \ - key += (attr.act_gate + attr.act_cand); \ - if (useJIT(attr.d)) { \ - /* only jit code need record d*/ \ - return key + "jit" + std::to_string(attr.d); \ - } else if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } \ - template <> \ - std::string ker_class##Impl::name(const gru_attr_t& attr) { \ - std::string key(#ker_key "d"); \ - /* jit code do not support double yet*/ \ - if (useMKL(attr.d)) { \ - return key + "mkl"; \ - } else { \ - return key + "any"; \ - } \ - } - -#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ - template <> \ - std::shared_ptr> \ - KernelPool::Get, const gru_attr_t&>( \ - const gru_attr_t& attr) - -#define JITKERNEL_FIND_KEY_GRU(ker_class, ker_dtype) \ - std::string key = ker_class##Impl::name(attr) - -#define JITKERNEL_GRU_IMPL(ker, dtype) \ - p = std::dynamic_pointer_cast>( \ - std::make_shared>(attr)); - -REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DEFINE_NAME_GRU, - JITKERNEL_DECLARE_GRU, JITKERNEL_FIND_KEY_GRU, - JITKERNEL_GRU_IMPL); - -#undef JITKERNEL_GRU_IMPL -#undef JITKERNEL_FIND_KEY_GRU -#undef JITKERNEL_DECLARE_GRU -#undef JITKERNEL_DEFINE_NAME_GRU -} // namespace jitkernel -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc deleted file mode 100644 index 19f7bd89094..00000000000 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ /dev/null @@ -1,742 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/jit_kernel.h" -#include // for exp -#include // for memcpy -#include -#include -#include -#include "gflags/gflags.h" -#include "glog/logging.h" -#include "gtest/gtest.h" -#include "paddle/fluid/operators/math/jit_kernel_refer.h" -#include "paddle/fluid/platform/port.h" - -#ifdef PADDLE_WITH_MKLML -#include "paddle/fluid/platform/dynload/mklml.h" -#endif - -#ifdef __AVX__ -#include -#endif - -constexpr int repeat = 20000; - -// TODO(TJ): benchmark and test should be seperated, -// benchmark should verify more sizes - -inline double GetCurrentUS() { - struct timeval time; - gettimeofday(&time, NULL); - return 1e+6 * time.tv_sec + time.tv_usec; -} - -template -void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), - const T upper = static_cast(20.f)) { - static unsigned int seed = 100; - std::mt19937 rng(seed++); - std::uniform_real_distribution uniform_dist(0, 1); - for (int i = 0; i < n; ++i) { - a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); - } -} - -#if defined __AVX__ || defined __AVX2__ -void vrelu_intri8(const int n, const float* x, float* y) { - __m256 tmp = _mm256_loadu_ps(x); - tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); - _mm256_storeu_ps(y, tmp); -} -#endif - -TEST(JitKernel, vrelu) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -10.f, 1.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VRelu(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vrelu_intri8(d, x_data, zref_data); - } - auto si1 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us"; - } -#endif - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -TEST(JitKernel, vaddbias) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float a = 2.f; - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VAddBias(&a, x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(&a, x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -#ifdef PADDLE_WITH_MKLML -void vexp_mkl(const int n, const float* x, float* y) { - paddle::platform::dynload::vsExp(n, x, y); -} -#endif - -TEST(JitKernel, vexp) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VExp(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vexp_mkl(d, x_data, zref_data); - } - auto tmkle = GetCurrentUS(); -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - // ker->Compute(x_data, ztgt_data); - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void vsigmoid_better( - const std::shared_ptr< - const paddle::operators::math::jitkernel::VExpKernel>& vexp, - const int n, const float* x, float* y) { - const float min = SIGMOID_THRESHOLD_MIN; - const float max = SIGMOID_THRESHOLD_MAX; - for (int i = 0; i < n; ++i) { - y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); - y[i] = 0.f - y[i]; - } - vexp->Compute(y, y, n); - for (int i = 0; i < n; ++i) { - y[i] = 1.f / (1.f + y[i]); - } -} - -TEST(JitKernel, vsigmoid) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const auto& vexp = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vsigmoid_better(vexp, d, x_data, zref_data); - } - auto tmkle = GetCurrentUS(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VSigmoid(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void vtanh_better( - const std::shared_ptr< - const paddle::operators::math::jitkernel::VScalKernel>& vscal, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VSigmoidKernel>& - vsigmoid, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VAddBiasKernel>& - vaddbias, - const int n, const float* x, float* y) { - const float a = 2.f, b = -1.f; - vscal->Compute(&a, x, y, n); - vsigmoid->Compute(y, y, n); - vscal->Compute(&a, y, y, n); - vaddbias->Compute(&b, y, y, n); -} - -TEST(JitKernel, vtanh) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { - std::vector x(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data(), -2.f, 2.f); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const auto& vscal = - jit::KernelPool::Instance().template Get>(d); - const auto& vsigmoid = - jit::KernelPool::Instance().template Get>(d); - const auto& vaddbias = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vtanh_better(vscal, vsigmoid, vaddbias, d, x_data, zref_data); - } - auto tmkle = GetCurrentUS(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VTanh(x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void lstm_ctht_better( - const std::shared_ptr< - const paddle::operators::math::jitkernel::VSigmoidKernel>& - vsigmoid_3d, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VTanhKernel>& vtanh_d, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VMulKernel>& vmul_d, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VAddKernel>& vadd_d, - const int d, float* gates, const float* ct_1, float* ct, float* ht) { - int d2 = d * 2; - vsigmoid_3d->Compute(gates + d, gates + d, 3 * d); - vtanh_d->Compute(gates, gates, d); - vmul_d->Compute(gates, gates + d, gates + d, d); - vmul_d->Compute(ct_1, gates + d2, gates + d2, d); - vadd_d->Compute(gates + d, gates + d2, ct, d); - /* H_t = act_cell(C_t) * ogated */ - vtanh_d->Compute(ct, gates + d2, d); - vmul_d->Compute(gates + d2, gates + d * 3, ht, d); -} - -TEST(JitKernel, lstm) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) { - int d4 = d * 4; - int d3 = d * 3; - std::vector x(d4), xref(d4); - std::vector ct_1(d), ct_tgt(d), ht_tgt(d); - std::vector ct_ref(d), ht_ref(d); - RandomVec(d4, x.data(), -2.f, 2.f); - RandomVec(d, ct_1.data(), -2.f, 2.f); - memcpy(xref.data(), x.data(), sizeof(float) * d4); - std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; - const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false); - const auto& ker = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>( - attr); - // below kernels are used to compute refer - const auto& vsigmoid_3d = - jit::KernelPool::Instance().template Get>( - d3); - const auto& vtanh_d = - jit::KernelPool::Instance().template Get>(d); - const auto& vmul_d = - jit::KernelPool::Instance().template Get>(d); - const auto& vadd_d = - jit::KernelPool::Instance().template Get>(d); - - float* x_data = x.data(); - float* xref_data = xref.data(); - const float* ct_1_data = ct_1.data(); - float* ct_tgt_data = ct_tgt.data(); - float* ht_tgt_data = ht_tgt.data(); - float* ct_ref_data = ct_ref.data(); - float* ht_ref_data = ht_ref.data(); - // compute once to check correctness - jit::lstm_t step; - step.gates = xref_data; - step.ct_1 = ct_1_data; - step.ct = ct_ref_data; - step.ht = ht_ref_data; - refer::LSTMCtHt(&step, &attr); - - step.gates = x_data; - step.ct = ct_tgt_data; - step.ht = ht_tgt_data; - ker->ComputeCtHt(&step, &attr); - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); - EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); - } - - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - lstm_ctht_better(vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, xref_data, - ct_1_data, ct_ref_data, ht_ref_data); - } - auto tmkle = GetCurrentUS(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::LSTMCtHt(&step, &attr); - } - auto trefe = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->ComputeCtHt(&step, &attr); - } - auto ttgte = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better(jit) takes: " << (tmkle - tmkls) / repeat - << " us, tgt takes: " << (ttgte - ttgts) / repeat << " us"; - } -} - -#if defined __AVX__ || defined __AVX2__ -void vscal_intri8(const int n, const float a, const float* x, float* y) { - __m256 tmp; - __m256 scalar = _mm256_set1_ps(a); - tmp = _mm256_loadu_ps(x); - tmp = _mm256_mul_ps(tmp, scalar); - _mm256_storeu_ps(y, tmp); -} -void vscal_inp_intri8(const int n, const float a, float* x) { - __m256 tmp; - __m256 scalar = _mm256_set1_ps(a); - tmp = _mm256_loadu_ps(x); - tmp = _mm256_mul_ps(tmp, scalar); - _mm256_storeu_ps(x, tmp); -} -#endif - -#ifdef PADDLE_WITH_MKLML -void vscal_inp_mkl(const int n, const float a, float* x) { - paddle::platform::dynload::cblas_sscal(n, a, x, 1); -} -#endif - -TEST(JitKernel, vscal) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - std::memcpy(y.data(), x.data(), sizeof(float) * d); - float a = 2.f; - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VScal(&a, x_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto trefs1 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VScal(&a, y_data, y_data, d); - } - auto trefe1 = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vscal_inp_mkl(d, a, y_data); - } - auto tmkle = GetCurrentUS(); -#endif - -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vscal_intri8(d, a, x_data, zref_data); - } - auto si1 = GetCurrentUS(); - auto si2 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vscal_inp_intri8(d, a, y_data); - } - auto si3 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat - << " us, inplace: " << (si3 - si2) / repeat << " us"; - } -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(&a, x_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - auto ttgts1 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(&a, y_data, y_data, d); - } - auto ttgte1 = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, inplace takes: " << (trefe1 - trefs1) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl inplace takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - << "tgt takes: " << (ttgte - ttgts) / repeat - << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -#if defined __AVX__ || defined __AVX2__ -void vmul_intri8(const int n, const float* x, const float* y, float* z) { - __m256 tmpx, tmpy; - tmpx = _mm256_loadu_ps(x); - tmpy = _mm256_loadu_ps(y); - tmpx = _mm256_mul_ps(tmpx, tmpy); - _mm256_storeu_ps(z, tmpx); -} -#endif - -#ifdef PADDLE_WITH_MKLML -void vmul_mkl(const int n, const float* x, const float* y, float* z) { - paddle::platform::dynload::vsMul(n, x, y, z); -} -#endif - -TEST(JitKernel, vmul) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - RandomVec(d, y.data()); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - const float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VMul(x_data, y_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vmul_mkl(d, x_data, y_data, zref_data); - } - auto tmkle = GetCurrentUS(); -#endif - -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vmul_intri8(d, x_data, y_data, zref_data); - } - auto si1 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; - } -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -#if defined __AVX__ || defined __AVX2__ -void vadd_intri8(const int n, const float* x, const float* y, float* z) { - __m256 tmpx, tmpy; - tmpx = _mm256_loadu_ps(x); - tmpy = _mm256_loadu_ps(y); - tmpx = _mm256_add_ps(tmpx, tmpy); - _mm256_storeu_ps(z, tmpx); -} -#endif - -#ifdef PADDLE_WITH_MKLML -void vadd_mkl(const int n, const float* x, const float* y, float* z) { - paddle::platform::dynload::vsAdd(n, x, y, z); -} -#endif - -TEST(JitKernel, vadd) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - RandomVec(d, y.data()); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - const float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VAdd(x_data, y_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - -#ifdef PADDLE_WITH_MKLML - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vadd_mkl(d, x_data, y_data, zref_data); - } - auto tmkle = GetCurrentUS(); -#endif - -#if defined __AVX__ || defined __AVX2__ - if (d == 8) { - auto si0 = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vadd_intri8(d, x_data, y_data, zref_data); - } - auto si1 = GetCurrentUS(); - VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; - } -#endif - - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat -#ifdef PADDLE_WITH_MKLML - << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " -#else - << " us, " -#endif - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -void vaddrelu_better( - const std::shared_ptr< - const paddle::operators::math::jitkernel::VAddKernel>& vadd, - const std::shared_ptr< - const paddle::operators::math::jitkernel::VReluKernel>& vrelu, - const float* x, const float* y, float* z, int d) { - vadd->Compute(x, y, z, d); - vrelu->Compute(z, z, d); -} - -TEST(JitKernel, vaddrelu) { - namespace jit = paddle::operators::math::jitkernel; - namespace refer = paddle::operators::math::jitkernel::refer; - for (int d : {7, 8, 15, 16, 30, 256, 512}) { - std::vector x(d), y(d); - std::vector zref(d), ztgt(d); - RandomVec(d, x.data()); - RandomVec(d, y.data()); - const auto& ker = - jit::KernelPool::Instance().template Get>(d); - const auto& vadd = - jit::KernelPool::Instance().template Get>(d); - const auto& vrelu = - jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); - const float* y_data = y.data(); - float* ztgt_data = ztgt.data(); - float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - refer::VAddRelu(x_data, y_data, zref_data, d); - } - auto trefe = GetCurrentUS(); - auto tmkls = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d); - } - auto tmkle = GetCurrentUS(); - auto ttgts = GetCurrentUS(); - for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data, d); - } - auto ttgte = GetCurrentUS(); - VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat - << " us, better takes: " << (tmkle - tmkls) / repeat << " us, " - << "tgt takes: " << (ttgte - ttgts) / repeat << " us"; - for (int i = 0; i < d; ++i) { - EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); - } - } -} - -TEST(JitKernel, pool) { - namespace jit = paddle::operators::math::jitkernel; - const int frame_size = 4; - std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; - jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false); - - // empty call it to avoid unknown flag 'use_pinned_memory' on Mac - paddle::platform::MayIUse(paddle::platform::avx); - const auto& plstm1 = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>(attr); - - const auto& plstm2 = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>(attr); - EXPECT_EQ(plstm1, plstm2); - - const auto& peephole = - jit::KernelPool::Instance() - .template Get, const jit::lstm_attr_t&>( - jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true)); - EXPECT_TRUE(plstm1 != peephole); - - const auto& pvmul_f = - jit::KernelPool::Instance().template Get>(4); - EXPECT_TRUE(std::dynamic_pointer_cast(plstm2) != - std::dynamic_pointer_cast(pvmul_f)); - - const auto& pvmul_d = - jit::KernelPool::Instance().template Get>(4); - EXPECT_TRUE(std::dynamic_pointer_cast(pvmul_f) != - std::dynamic_pointer_cast(pvmul_d)); - - const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfjit4"); -#if defined(__APPLE__) || defined(__OSX__) || defined(_WIN32) - EXPECT_EQ(pvmul_from_key, nullptr); -#else - EXPECT_EQ(pvmul_from_key, pvmul_f); -#endif - const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulfjit"); - EXPECT_TRUE(pvmul_from_key2 == nullptr); -} -- GitLab