Commit 7ef2699e authored by T tensor-tang

init peephole runtime kernel

Parent 3ee8f2c6
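For orientation: a peephole LSTM lets the input, forget, and output gates see the cell state through extra diagonal weights, and the `checked` buffer threaded through `ComputeCtHt` in this diff is presumably the scratch space for those terms (at this stage the new `PeepholeKernelImpl` still computes the plain cell update and does not consume `checked`). A standard peephole formulation is sketched below for reference only; the symbols `W_ic`, `W_fc`, `W_oc` are not part of this commit.

```latex
% Standard peephole LSTM cell (reference formulation, not code from this commit).
\begin{aligned}
i_t &= \sigma(W_{ix} x_t + W_{ih} h_{t-1} + W_{ic} \odot c_{t-1} + b_i) \\
f_t &= \sigma(W_{fx} x_t + W_{fh} h_{t-1} + W_{fc} \odot c_{t-1} + b_f) \\
\tilde{c}_t &= \mathrm{act\_cand}(W_{cx} x_t + W_{ch} h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \sigma(W_{ox} x_t + W_{oh} h_{t-1} + W_{oc} \odot c_t + b_o) \\
h_t &= o_t \odot \mathrm{act\_cell}(c_t)
\end{aligned}
```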
@@ -400,10 +400,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     } else {
       const auto& ker =
           math::jitkernel::KernelPool::Instance()
-              .template Get<math::jitkernel::LSTMKernel<T>, int,
-                            const std::string&, const std::string&,
-                            const std::string&>(D, act_gate_str, act_cand_str,
-                                                act_cell_str);
+              .template Get<math::jitkernel::LSTMKernel<T>, const std::string&,
+                            const std::string&, const std::string&>(
+                  act_gate_str, act_cand_str, act_cell_str, D, false);
       for (int i = 0; i < N; ++i) {
         PROCESS_H0C0
@@ -545,10 +544,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     } else {
       const auto& ker =
           math::jitkernel::KernelPool::Instance()
-              .template Get<math::jitkernel::LSTMKernel<T>, int,
-                            const std::string&, const std::string&,
-                            const std::string&>(D, act_gate_str, act_cand_str,
-                                                act_cell_str);
+              .template Get<math::jitkernel::LSTMKernel<T>, const std::string&,
+                            const std::string&, const std::string&>(
+                  act_gate_str, act_cand_str, act_cell_str, D, false);
       for (int step = tstart; step < max_seq_len; ++step) {
         const int cur_bs = batch_starts[step + 1] - batch_starts[step];
......
@@ -125,7 +125,8 @@ class VTanhKernel : public VActKernel<T> {
 template <typename T>
 class LSTMKernel : public Kernel {
  public:
-  virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht) const = 0;
+  virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
+                           T *checked = nullptr) const = 0;
 };
 }  // namespace jitkernel
......
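A minimal caller sketch for the widened interface above. Assumptions not taken from the diff: the include path, the namespace alias, and the buffer names (`gates`, `ct_1`, `ct`, `ht`, `checked_buf`) are illustrative; existing call sites such as the fused LSTM op keep compiling because `checked` defaults to `nullptr`.

```cpp
// Sketch only, not part of the commit. Assumes the jitkernel header lives at
// this path and that the buffers are pre-allocated with the layout the kernel
// expects (gates holds 4 * d values, the other buffers hold d values each).
#include <string>
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jit = paddle::operators::math::jitkernel;

void LstmStep(int d, float* gates, const float* ct_1, float* ct, float* ht,
              float* checked_buf) {
  const auto& ker =
      jit::KernelPool::Instance()
          .template Get<jit::LSTMKernel<float>, const std::string&,
                        const std::string&, const std::string&>(
              "sigmoid", "tanh", "tanh", d, /*use_peephole=*/false);
  // Old-style call still works: `checked` defaults to nullptr.
  ker->ComputeCtHt(gates, ct_1, ct, ht);
  // New-style call: pass scratch space for the peephole ("checked") terms.
  ker->ComputeCtHt(gates, ct_1, ct, ht, checked_buf);
}
```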
@@ -86,9 +86,9 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
 template <typename T, jit::cpu_isa_t isa, jit_block>
 class LSTMKernelImpl : public LSTMKernel<T> {
  public:
-  explicit LSTMKernelImpl(int d, const std::string& act_gate,
-                          const std::string& act_cand,
-                          const std::string& act_cell)
+  explicit LSTMKernelImpl(const std::string& act_gate,
+                          const std::string& act_cand,
+                          const std::string& act_cell, int d)
       : LSTMKernel<T>() {
     d_ = d;
     d2_ = d * 2;
@@ -134,7 +134,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
 #endif
   }
-  void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override {
+  void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht,
+                   T* checked) const override {
     // gates: W_ch, W_ih, W_fh, W_oh
     act_gate_3d_->Compute(gates + d_, gates + d_);
@@ -162,7 +163,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
 #define INTRI8_FLOAT(isa)                                                    \
   template <>                                                                \
   void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt(                        \
-      float* gates, const float* ct_1, float* ct, float* ht) const {         \
+      float* gates, const float* ct_1, float* ct, float* ht, float* checked) \
+      const {                                                                \
     /* gates: W_ch, W_ih, W_fh, W_oh */                                      \
     __m256 c, i, f, o;                                                       \
     c = _mm256_loadu_ps(gates);                                              \
@@ -192,21 +194,86 @@ INTRI8_FLOAT(jit::avx2);
 INTRI8_FLOAT(jit::avx512f);
 #endif
 
+/* Peephole JitKernel */
+template <typename T, jit::cpu_isa_t isa, jit_block>
+class PeepholeKernelImpl : public LSTMKernel<T> {
+ public:
+  explicit PeepholeKernelImpl(const std::string& act_gate,
+                              const std::string& act_cand,
+                              const std::string& act_cell, int d)
+      : LSTMKernel<T>() {
+    d_ = d;
+    d2_ = d * 2;
+    d3_ = d * 3;
+    auto GetActKernel = [&](const std::string& type,
+                            int n) -> std::shared_ptr<const VActKernel<T>> {
+      if (type == "sigmoid") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
+      } else if (type == "relu") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VReluKernel<T>>(n));
+      } else if (type == "tanh") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VTanhKernel<T>>(n));
+      } else if (type == "identity" || type == "") {
+        return std::dynamic_pointer_cast<const VActKernel<T>>(
+            KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
+      }
+      PADDLE_THROW("Not support type: %s", type);
+    };
+    act_gate_3d_ = GetActKernel(act_gate, d * 3);
+    act_cand_d_ = GetActKernel(act_cand, d);
+    act_cell_d_ = GetActKernel(act_cell, d);
+    vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
+    vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
+  }
+
+  void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht,
+                   T* checked) const override {
+    // gates: W_ch, W_ih, W_fh, W_oh
+    act_gate_3d_->Compute(gates + d_, gates + d_);
+
+    /* C_t = C_t-1 * fgated + cand_gated * igated */
+    act_cand_d_->Compute(gates, gates);
+    vmul_d_->Compute(gates, gates + d_, gates + d_);
+    vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
+    vadd_d_->Compute(gates + d_, gates + d2_, ct);
+
+    /* H_t = act_cell(C_t) * ogated */
+    act_cell_d_->Compute(ct, gates + d2_);
+    vmul_d_->Compute(gates + d2_, gates + d3_, ht);
+  }
+
+ private:
+  int d_, d2_, d3_;
+  std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
+  std::shared_ptr<const VMulKernel<T>> vmul_d_;
+  std::shared_ptr<const VAddKernel<T>> vadd_d_;
+};
+
 #define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype)                    \
   template <>                                                           \
-  std::shared_ptr<const ker_class<ker_dtype>>                           \
-  KernelPool::Get<ker_class<ker_dtype>, int, const std::string&,        \
-                  const std::string&, const std::string&>(              \
-      int d, const std::string& act_gate, const std::string& act_cand,  \
-      const std::string& act_cell)
+  std::shared_ptr<const LSTMKernel<ker_dtype>>                        \
+  KernelPool::Get<LSTMKernel<ker_dtype>, const std::string&,          \
+                  const std::string&, const std::string&, int, bool>( \
+      const std::string& act_gate, const std::string& act_cand,       \
+      const std::string& act_cell, int d, bool use_peephole)
 
 #define JITKERNEL_KEY_LSTM(ker_key, dtype_key)                               \
-  #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell
+  #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \
+      (use_peephole ? "p" : "n")
 
 #define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k)                      \
-  p = std::dynamic_pointer_cast<ker<dtype>>(                             \
-      std::make_shared<ker##Impl<dtype, isa, k>>(d, act_gate, act_cand,  \
-                                                 act_cell))
+  if (use_peephole) {                                                  \
+    p = std::dynamic_pointer_cast<ker<dtype>>(                         \
+        std::make_shared<PeepholeKernelImpl<dtype, isa, k>>(           \
+            act_gate, act_cand, act_cell, d));                         \
+  } else {                                                             \
+    p = std::dynamic_pointer_cast<ker<dtype>>(                         \
+        std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_cand, \
                                                   act_cell, d));      \
+  }
 
 REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
                         JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
@@ -215,7 +282,6 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
 #undef JITKERNEL_DECLARE_LSTM
 #undef JITKERNEL_KEY_LSTM
 #undef JITKERNEL_NEW_LSTM_IMPL
-
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators
......
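The only thing keeping the two variants apart in the pool cache is the key suffix appended by JITKERNEL_KEY_LSTM above. The standalone snippet below just illustrates the strings that macro would build; the dtype key "f" for float is an assumption taken from the usual registration convention, not shown in this diff.

```cpp
// Illustration only (not commit code): cache keys as built by
// JITKERNEL_KEY_LSTM for ker_key=lstm, dtype_key=f, d=8, sigmoid/tanh/tanh.
#include <iostream>
#include <string>

int main() {
  const int d = 8;
  const std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
  for (bool use_peephole : {false, true}) {
    // Mirrors: #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand
    //          + act_cell + (use_peephole ? "p" : "n")
    std::string key = std::string("lstm") + "f" + std::to_string(d) + act_gate +
                      act_cand + act_cell + (use_peephole ? "p" : "n");
    std::cout << key << std::endl;  // lstmf8sigmoidtanhtanhn, then ...p
  }
  return 0;
}
```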
@@ -390,9 +390,9 @@ TEST(JitKernel, lstm) {
   std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
   const auto& ker =
       jit::KernelPool::Instance()
-          .template Get<jit::LSTMKernel<float>, int, const std::string&,
-                        const std::string&, const std::string&>(
-              d, act_gate, act_cand, act_cell);
+          .template Get<jit::LSTMKernel<float>, const std::string&,
+                        const std::string&, const std::string&>(
+              act_gate, act_cand, act_cell, d, false);
   // below kernels are used to compute refer
   const auto& vsigmoid_3d =
       jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
@@ -717,15 +717,20 @@ TEST(JitKernel, pool) {
   std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
   const auto& plstm1 =
       jit::KernelPool::Instance()
-          .template Get<jit::LSTMKernel<float>, int, const std::string&,
-                        const std::string&, const std::string&>(
-              frame_size, act_gate, act_cand, act_cell);
+          .template Get<jit::LSTMKernel<float>, const std::string&,
+                        const std::string&, const std::string&>(
+              act_gate, act_cand, act_cell, frame_size, false);
   const auto& plstm2 =
       jit::KernelPool::Instance()
-          .template Get<jit::LSTMKernel<float>, int, const std::string&,
-                        const std::string&, const std::string&>(
-              frame_size, act_gate, act_cand, act_cell);
-  EXPECT_EQ(plstm1, plstm2);
+          .template Get<jit::LSTMKernel<float>, const std::string&,
+                        const std::string&, const std::string&>(
+              act_gate, act_cand, act_cell, frame_size, false);
+  const auto& peephole =
+      jit::KernelPool::Instance()
+          .template Get<jit::LSTMKernel<float>, const std::string&,
+                        const std::string&, const std::string&>(
+              act_gate, act_cand, act_cell, frame_size, true);
+  EXPECT_TRUE(plstm1 != peephole);
 
   const auto& pvmul_f =
       jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
......