Commit 92031968 authored by T tensor-tang

init vmul kernel

Parent b9acbcc8
@@ -16,23 +16,132 @@ limitations under the License. */
#include <functional>
#include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace jit = platform::jit;
KernelPool& KernelPool::Instance() {
  static KernelPool g_jit_kernels;
  return g_jit_kernels;
}
#define SEARCH_BLOCK(src, t, isa) \
if (d < AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kLT8>; \
} else if (d == AVX_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ8>; \
} else if (d == AVX512_FLOAT_BLOCK) { \
Compute = src<t, isa, kEQ16>; \
} else { \
Compute = src<t, isa, kGT16>; \
}
#define SEARCH_ISA_BLOCK(src, t) \
if (jit::MayIUse(jit::avx512_common)) { \
SEARCH_BLOCK(src, t, jit::avx512_common); \
} else if (jit::MayIUse(jit::avx2)) { \
SEARCH_BLOCK(src, t, jit::avx2); \
} else if (jit::MayIUse(jit::avx)) { \
SEARCH_BLOCK(src, t, jit::avx); \
} else { \
SEARCH_BLOCK(src, t, jit::isa_any); \
}
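For illustration only (this expansion is not part of the commit): with d == 8 on a machine where jit::avx2 is the highest usable ISA, SEARCH_ISA_BLOCK(VMulCompute, float) resolves the function pointer as

  Compute = VMulCompute<float, jit::avx2, kEQ8>;

Sizes other than 8 and 16 fall into kLT8 (d < 8) or the kGT16 fallthrough (everything else).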
#define FOR_EACH_BLOCK(macro_, isa) \
macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kEQ16) macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, jit::avx512_common) \
FOR_EACH_BLOCK(macro_, jit::avx2) \
FOR_EACH_BLOCK(macro_, jit::avx) \
  FOR_EACH_BLOCK(macro_, jit::isa_any)
#define VMUL_ANY \
for (int i = 0; i < n; ++i) { \
z[i] = x[i] * y[i]; \
}
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
static void VMulCompute(const int n, const T* x, const T* y, T* z) {
VMUL_ANY
}
#ifdef PADDLE_WITH_MKLML
#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block)                      \
  template <>                                                      \
  void VMulCompute<float, isa, block>(const int n, const float* x, \
                                      const float* y, float* z) {  \
    platform::dynload::vsMul(n, x, y, z);                          \
  }
#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block)                       \
  template <>                                                        \
  void VMulCompute<double, isa, block>(const int n, const double* x, \
                                       const double* y, double* z) { \
    platform::dynload::vdMul(n, x, y, z);                            \
  }
FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT)
FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE)
// TODO(TJ): add EQ8
#endif
#undef DEFINE_VMUL_COMPUTE_FLOAT
#undef DEFINE_VMUL_COMPUTE_DOUBLE
#undef VMUL_ANY
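As a concrete example of what the MKL path generates (an illustrative expansion, not additional code in the commit), DEFINE_VMUL_COMPUTE_FLOAT(jit::avx2, kEQ8) produces a specialization equivalent to:

  template <>
  void VMulCompute<float, jit::avx2, kEQ8>(const int n, const float* x,
                                           const float* y, float* z) {
    platform::dynload::vsMul(n, x, y, z);  // MKL element-wise multiply
  }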
template <>
VMulKernel<float>::VMulKernel(int d) {
SEARCH_ISA_BLOCK(VMulCompute, float);
}
template <>
VMulKernel<double>::VMulKernel(int d) {
SEARCH_ISA_BLOCK(VMulCompute, double);
}
template <>
const std::shared_ptr<VMulKernel<float>> KernelPool::Get<VMulKernel<float>>(
int d) {
std::string key = "f" + std::to_string(d);
if (kers_.find(key) == kers_.end()) {
auto p = std::make_shared<VMulKernel<float>>(d);
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
return p;
}
return std::dynamic_pointer_cast<VMulKernel<float>>(kers_.at(key));
}
template <>
const std::shared_ptr<VMulKernel<double>> KernelPool::Get<VMulKernel<double>>(
int d) {
std::string key = "d" + std::to_string(d);
if (kers_.find(key) == kers_.end()) {
auto p = std::make_shared<VMulKernel<double>>(d);
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
return p;
}
return std::dynamic_pointer_cast<VMulKernel<double>>(kers_.at(key));
}
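A minimal usage sketch (illustrative only; it assumes the header shown further below is paddle/fluid/operators/math/jit_kernel.h, since the file names are not visible in this excerpt):

  #include "paddle/fluid/operators/math/jit_kernel.h"

  void ExampleVMul() {
    namespace jk = paddle::operators::math::jitkernel;
    const int n = 8;
    float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float y[8] = {8, 7, 6, 5, 4, 3, 2, 1};
    float z[8];
    // Repeated Get calls with the same element type and size hit the same
    // "f8" cache entry and return the same shared kernel instance.
    const auto& ker = jk::KernelPool::Instance().Get<jk::VMulKernel<float>>(n);
    ker->Compute(n, x, y, z);  // z[i] = x[i] * y[i]
  }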
template <>
LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
                              const std::string& act_cand_str,
                              const std::string& act_cell_str)
    : Kernel(), d_(d) {
d2_ = d * 2;
d3_ = d * 3;
  if (platform::jit::MayIUse(platform::jit::avx512_common)) {
    math::VecActivations<float, platform::jit::avx512_common> act_functor;
    act_gate_ = act_functor(act_gate_str);
@@ -48,6 +157,22 @@ LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
    act_gate_ = act_functor(act_gate_str);
    act_cell_ = act_functor(act_cell_str);
    act_cand_ = act_functor(act_cand_str);
// ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) {
// // gates: W_ch, W_ih, W_fh, W_oh
// act_gate(d3_, gates + d_, gates + d_);
// /* C_t = C_t-1 * fgated + cand_gated * igated */
// act_cand(d_, gates, gates);
// blas.VMUL(d_, gates, gates + d_, gates + d_);
// blas.VMUL(d_, ct_1, gates + d2_, gates + d2_);
// blas.VADD(d_, gates + d_, gates + d2_, ct);
// /* H_t = act_cell(C_t) * ogated */
// act_cell(d_, ct, gates + d2_);
// blas.VMUL(d_, gates + d2_, gates + d3_, ht);
// GET_Ct(ct_1, gates, ct);
// GET_Ht(ct, gates, ht);
// };
  } else {
    math::VecActivations<float, platform::jit::isa_any> act_functor;
    act_gate_ = act_functor(act_gate_str);
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory>  // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"

// Note: Only supported on CPU for now.
@@ -25,6 +26,18 @@ namespace operators {
namespace math {
namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define AVX_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4
#define AVX2_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4
#define AVX512_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8
typedef enum { kLT8, kEQ8, kEQ16, kGT16 } jit_block;
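As a reading aid (not in the original header), SEARCH_BLOCK in the .cc file maps a vector length d onto these blocks as follows:

  //   d = 3  -> kLT8   (d < AVX_FLOAT_BLOCK)
  //   d = 8  -> kEQ8   (d == AVX_FLOAT_BLOCK)
  //   d = 16 -> kEQ16  (d == AVX512_FLOAT_BLOCK)
  //   d = 12 -> kGT16  (the fallthrough branch, so 8 < d < 16 also lands here)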
class Kernel {
 public:
  Kernel() {}
@@ -36,7 +49,7 @@ class Kernel {
class KernelPool {
 public:
  static KernelPool &Instance();

  template <typename Ker, typename... ARGS>
  const std::shared_ptr<Ker> Get(ARGS... args);

@@ -48,17 +61,24 @@ class KernelPool {
  DISABLE_COPY_AND_ASSIGN(KernelPool);
};
template <typename T>
class VMulKernel : public Kernel {
public:
explicit VMulKernel(int n);
void (*Compute)(const int n, const T *, const T *, T *);
};
template <typename T>
class LSTMKernel : public Kernel {
 public:
  explicit LSTMKernel(int d, const std::string &act_gate,
                      const std::string &act_cand, const std::string &act_cell);

  void (*jit_ker)(T *, const T *, T *, T *);
  std::function<void(T *, const T *, T *, T *)> ComputeCtHt, ComputeCtHt_NoC0H0;

 private:
  int d_, d2_, d3_;
  std::function<void(const int, const T *, T *)> act_gate_, act_cell_,
      act_cand_;
};
...
@@ -23,10 +23,25 @@ TEST(JitKernel, pool) {
  namespace jit = paddle::operators::math::jitkernel;
  const int frame_size = 4;
  std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
  const auto& p1 =
      jit::KernelPool::Instance()
          .template Get<jit::LSTMKernel<float>, int, const std::string&,
                        const std::string&, const std::string&>(
              frame_size, act_gate, act_cand, act_cell);
  const auto& p2 =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, int, const std::string&,
const std::string&, const std::string&>(
frame_size, act_gate, act_cand, act_cell);
EXPECT_EQ(p1, p2);
const auto& p3 =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p2) !=
std::dynamic_pointer_cast<jit::Kernel>(p3));
const auto& p4 =
jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p3) !=
std::dynamic_pointer_cast<jit::Kernel>(p4));
}
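A natural follow-up check (a sketch only, reusing the gtest setup of the existing test) would exercise the selected Compute pointer and compare it against the scalar reference:

  TEST(JitKernel, vmul_result) {
    namespace jit = paddle::operators::math::jitkernel;
    const int n = 4;
    float x[4] = {1.f, 2.f, 3.f, 4.f};
    float y[4] = {4.f, 3.f, 2.f, 1.f};
    float z[4] = {0.f, 0.f, 0.f, 0.f};
    const auto& ker =
        jit::KernelPool::Instance().Get<jit::VMulKernel<float>>(n);
    ker->Compute(n, x, y, z);
    for (int i = 0; i < n; ++i) {
      EXPECT_NEAR(z[i], x[i] * y[i], 1e-5f);  // element-wise product
    }
  }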