/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include // for shared_ptr #include #include #include "paddle/fluid/operators/math/jit_kernel_impl.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/macros.h" // Note: Only support on CPU yet. namespace paddle { namespace operators { namespace math { namespace jitkernel { // TODO(TJ): remove me typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; class Kernel { public: Kernel() = default; virtual ~Kernel() = default; // TODO(TJ): below members should be deprecated. int num_{0}; int end_{0}; int rest_{0}; DISABLE_COPY_AND_ASSIGN(Kernel); }; class KernelPool { public: static KernelPool &Instance(); template std::shared_ptr Get(ARGS... args); std::shared_ptr Get(const std::string &key) const; private: KernelPool() = default; std::unordered_map> kers_; DISABLE_COPY_AND_ASSIGN(KernelPool); }; template class VMulKernel : public Kernel { public: void (*Compute)(const T *, const T *, T *, int); }; template class VAddKernel : public Kernel { public: void (*Compute)(const T *, const T *, T *, int); }; template class VAddReluKernel : public Kernel { public: void (*Compute)(const T *, const T *, T *, int); }; template class VScalKernel : public Kernel { public: // y = a.*x void (*Compute)(const T *, const T *, T *, int); }; template class VAddBiasKernel : public Kernel { public: // y = a.+x void (*Compute)(const T *, const T *, T *, int); }; template class VActKernel : public Kernel { public: void (*Compute)(const T *, T *, int); }; template class VReluKernel : public VActKernel {}; template class VIdentityKernel : public VActKernel {}; template class VExpKernel : public VActKernel {}; template class VSigmoidKernel : public VActKernel {}; template class VTanhKernel : public VActKernel {}; template class LSTMKernel : public Kernel { public: virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht, /* below only used in peephole*/ const T *wp_data = nullptr, T *checked = nullptr) const = 0; virtual void ComputeC1H1(T *gates, T *ct, T *ht, /* below only used in peephole*/ const T *wp_data = nullptr) const = 0; // void (*ComputeCtHt)(lstm_t *); // // compute c1 and h1 without c0 or h0 // void (*ComputeC1H1)(lstm_t *); }; template class GRUKernel : public Kernel { public: // compute h1 without h0 virtual void ComputeH1(T *gates, T *ht) const = 0; virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0; virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; }; template class CRFDecodeKernel : public Kernel { public: virtual void Compute(const int seq_len, const T *x, const T *w, T *alpha, int *track) const = 0; }; template class LayerNormKernel : public Kernel { public: virtual void Compute(T *x, T *out, T *mean, T *var, const T *scale, const T *bias, int height, const float epsilon) const = 0; }; } // namespace jitkernel } // namespace math } // namespace operators } // namespace paddle