未验证 提交 9cb8738f 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #14018 from tensor-tang/refine/jit/gru

Refine/jit/gru
...@@ -16,10 +16,9 @@ limitations under the License. */ ...@@ -16,10 +16,9 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
} }
#define INIT_VEC_FUNC \ #define INIT_BASE_DEFINES \
std::function<void(const int, const T *, T *)> act_gate, act_state; \ auto* x = ctx.Input<LoDTensor>("X"); \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \ auto x_lod = x->lod(); \
if (platform::jit::MayIUse(platform::jit::avx)) { \ auto x_dims = x->dims(); /* T x M*/ \
math::VecActivations<T, platform::jit::avx> act_functor; \ auto wh_dims = wh->dims(); /* D x 3D*/ \
act_gate = act_functor(act_gate_str); \ const int total_T = x_dims[0]; \
act_state = act_functor(act_state_str); \ const int D3 = wh_dims[1]
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \ #define INIT_OTHER_DEFINES \
math::VecActivations<T, platform::jit::isa_any> act_functor; \ auto* h0 = ctx.Input<Tensor>("H0"); \
act_gate = act_functor(act_gate_str); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
act_state = act_functor(act_state_str); \ auto* bias = ctx.Input<Tensor>("Bias"); \
cross = math::vec_cross<T, platform::jit::isa_any>; \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
} bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
#define INIT_BASE_INPUT_OUTPUT \ const int D = wh_dims[0]; \
auto* h0 = ctx.Input<Tensor>("H0"); \ const int D2 = D * 2; \
auto* wx = ctx.Input<Tensor>("WeightX"); \ const auto& ker = math::jitkernel::KernelPool::Instance() \
auto* wh = ctx.Input<Tensor>("WeightH"); \ .template Get<math::jitkernel::GRUKernel<T>, \
auto* bias = ctx.Input<Tensor>("Bias"); \ const std::string&, const std::string&>( \
auto* xx = ctx.Output<LoDTensor>("XX"); \ ctx.Attr<std::string>("gate_activation"), \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ ctx.Attr<std::string>("activation"), D); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
#define INIT_BASE_SIZES \ const T* wh_data = wh->data<T>(); \
auto x_dims = x->dims(); /* T x M*/ \ auto place = ctx.GetPlace(); \
auto wh_dims = wh->dims(); /* D x 3D*/ \ T* xx_data = xx->mutable_data<T>(place)
const int total_T = x_dims[0]; \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D3 = wh_dims[1]; \
const int D2 = D * 2;
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT INIT_OTHER_DEFINES;
INIT_BASE_SIZES
INIT_VEC_FUNC
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; const int N = x_lod[0].size() - 1;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* wh_state_data = wh_data + D * D2; const T* wh_state_data = wh_data + D * D2;
T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data<T>(place);
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data, math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data, xx_data,
...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
// W: {W_update, W_reset; W_state} ker->ComputeH1(xx_data, hidden_out_data);
// update gate
act_gate(D, xx_data, xx_data);
// state gate
act_state(D, xx_data + D2, xx_data + D2);
// out = a*b
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
act_gate(D2, xx_data, xx_data); ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
act_state(D, xx_data + D2, xx_data + D2); ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT if (x_lod[0].size() == 2) {
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
xx->Resize({total_T, D3}); xx->Resize({total_T, D3});
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_VEC_FUNC INIT_OTHER_DEFINES;
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput"); auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut"); auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T* batched_input_data = batched_input->mutable_data<T>(place);
const T* x_data = x->data<T>(); T* batched_out_data = batched_out->mutable_data<T>(place);
const T* wx_data = wx->data<T>(); hidden_out->mutable_data<T>(place);
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* prev_hidden_data = nullptr; T* prev_hidden_data = nullptr;
if (h0) { if (h0) {
// reorder h0 // reorder h0
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace()); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
const T* h0_data = h0->data<T>(); const T* h0_data = h0->data<T>();
prev_hidden_data = reordered_h0_data; prev_hidden_data = reordered_h0_data;
size_t sz = sizeof(T) * D; size_t sz = sizeof(T) * D;
...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
// update gate ker->ComputeH1(cur_in_data, cur_out_data);
act_gate(D, cur_in_data, cur_in_data);
// state gate
act_state(D, cur_in_data + D2, cur_in_data + D2);
// out = a*b
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_batched_data, cur_batched_data); ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
// rt = rt*ht_1 inplace result cur_out_data);
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
// ht~ = act_state(...) ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
act_state(D, cur_batched_data + D2, cur_batched_data + D2); cur_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out->set_lod(batched_lod); batched_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_out, hidden_out); to_seq(dev_ctx, *batched_out, hidden_out);
} }
#undef INIT_VEC_FUNC #undef INIT_OTHER_DEFINES
#undef INIT_BASE_SIZES #undef INIT_BASE_DEFINES
#undef INIT_BASE_INPUT_OUTPUT
}; };
} // namespace operators } // namespace operators
......
...@@ -75,6 +75,6 @@ endif() ...@@ -75,6 +75,6 @@ endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
DEPS cpu_info cblas) DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { ...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
const T *wp_data = nullptr) const = 0; const T *wp_data = nullptr) const = 0;
}; };
// Abstract interface for the element-wise part of one GRU time step.
//
// `gates` points at the 3 * d pre-activation values of a single sample, laid
// out as {update, reset, state}; `ht_1` and `ht` are the previous and current
// hidden vectors of length d. Implementations may or may not write activated
// values back into `gates` (the generic implementation in this file does, the
// 8-wide AVX specialization does not), so callers should not depend on the
// contents of `gates` afterwards beyond what the Part1 -> GEMM -> Part2
// sequence requires.
template <typename T>
class GRUKernel : public Kernel {
 public:
  // First step when no initial hidden state is given:
  //   ht = act_gate(update) * act_state(state)
  virtual void ComputeH1(T* gates, T* ht) const = 0;

  // First half of a regular step: activates the gates and stores
  //   ht = act_gate(reset) * ht_1
  // so the caller can feed `ht` into the state GEMM.
  virtual void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const = 0;

  // Second half of a regular step: activates the state and blends
  //   ht = zt * act_state(state) + (1 - zt) * ht_1
  // where zt is the activated update gate.
  virtual void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const = 0;
};
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel( ...@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return nullptr; return nullptr;
} }
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class LSTMKernelImpl : public LSTMKernel<T> { class LSTMKernelImpl : public LSTMKernel<T> {
...@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \ LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \ const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \ const std::string& act_cell, int d) \
: LSTMKernel<float>() { \ : LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
if (type == "sigmoid") { \ avx_act_cand_ = GetAVXAct<isa>(act_cand); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \ avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} else if (type == "relu") { \ } \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \ template <> \
} else if (type == "tanh") { \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \ float* gates, const float* ct_1, float* ct, float* ht, \
} else if (type == "identity" || type == "") { \ const float* wp_data, float* checked) const { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \ /* gates: W_ch, W_ih, W_fh, W_oh */ \
} \ __m256 c, i, f, o; \
PADDLE_THROW("Not support type: %s", type); \ c = _mm256_loadu_ps(gates); \
}; \ i = _mm256_loadu_ps(gates + 8); \
avx_act_gate_ = GetAVXAct(act_gate); \ f = _mm256_loadu_ps(gates + 16); \
avx_act_cand_ = GetAVXAct(act_cand); \ o = _mm256_loadu_ps(gates + 24); \
avx_act_cell_ = GetAVXAct(act_cell); \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
} \ c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
template <> \ i = _mm256_loadu_ps(ct_1); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
float* gates, const float* ct_1, float* ct, float* ht, \ f = _mm256_add_ps(c, f); \
const float* wp_data, float* checked) const { \ _mm256_storeu_ps(ct, f); \
/* gates: W_ch, W_ih, W_fh, W_oh */ \ /* H_t = act_cell(C_t) * ogated */ \
__m256 c, i, f, o; \ o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
c = _mm256_loadu_ps(gates); \ _mm256_storeu_ps(ht, o); \
i = _mm256_loadu_ps(gates + 8); \ } \
f = _mm256_loadu_ps(gates + 16); \ template <> \
o = _mm256_loadu_ps(gates + 24); \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \ float* gates, float* ct, float* ht, const float* wp_data) const { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ __m256 c, i, o; \
i = _mm256_loadu_ps(ct_1); \ c = _mm256_loadu_ps(gates); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ i = _mm256_loadu_ps(gates + 8); \
f = _mm256_add_ps(c, f); \ o = _mm256_loadu_ps(gates + 24); \
_mm256_storeu_ps(ct, f); \ /* C_t = igated * cgated*/ \
/* H_t = act_cell(C_t) * ogated */ \ c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ _mm256_storeu_ps(ct, c); \
_mm256_storeu_ps(ht, o); \ /* H_t = act_cell(C_t) * ogated */ \
} \ o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
template <> \ _mm256_storeu_ps(ht, o); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} }
// TODO(TJ): optimize keq16 // TODO(TJ): optimize keq16
...@@ -354,6 +357,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, ...@@ -354,6 +357,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
#undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM #undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL #undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */
// Generic GRU kernel built from the vector activation and multiply kernels.
//
// Gate buffer layout for one sample is {W_update, W_reset; W_state}, i.e.
// gates[0, d) is the update gate, gates[d, 2d) the reset gate and
// gates[2d, 3d) the candidate state. Activations are applied with the same
// pointer as input and output, so `gates` is overwritten in place.
template <typename T, jit::cpu_isa_t isa, jit_block>
class GRUKernelImpl : public GRUKernel<T> {
 public:
  // act_gate / act_state are activation names understood by GetActKernel;
  // d is the hidden size.
  explicit GRUKernelImpl(const std::string& act_gate,
                         const std::string& act_state, int d)
      : GRUKernel<T>(),
        d_(d),
        d2_(d * 2),
        // Members are initialized in declaration order, so d2_ is already
        // set when it is used here.
        act_gate_d2_(GetActKernel<T>(act_gate, d2_)),
        act_gate_d_(GetActKernel<T>(act_gate, d)),
        act_state_d_(GetActKernel<T>(act_state, d)),
        vmul_d_(KernelPool::Instance().template Get<VMulKernel<T>>(d)) {}

  // ht = act_gate(update) * act_state(state); the reset gate is untouched.
  void ComputeH1(T* gates, T* ht) const override {
    T* state = gates + d2_;
    act_gate_d_->Compute(gates, gates);
    act_state_d_->Compute(state, state);
    vmul_d_->Compute(gates, state, ht);
  }

  // Activates update and reset gates together, then ht = ht_1 * reset.
  void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
    // W: {W_update, W_reset; W_state}
    const T* reset = gates + d_;
    act_gate_d2_->Compute(gates, gates);
    vmul_d_->Compute(ht_1, reset, ht);
  }

  // Activates the candidate state and blends it with the previous hidden
  // state using the (already activated) update gate.
  void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
    T* cand = gates + d2_;
    act_state_d_->Compute(cand, cand);
    // out = zt*ht~ + (1-zt)*ht_1
    const T one = static_cast<T>(1);
    for (int k = 0; k < d_; ++k) {
      const T zt = gates[k];
      ht[k] = zt * cand[k] + (one - zt) * ht_1[k];
    }
  }

 private:
  int d_, d2_;
  std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
  std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
  // Only used by the explicit kEQ8 float specializations.
  std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
};
// Explicit float specializations of GRUKernelImpl for the kEQ8 block, using
// one 8-lane AVX register per gate (update at gates + 0, reset at gates + 8,
// state at gates + 16).
// Unlike the generic implementation these only load from `gates` and store to
// `ht`; activated gate values are never written back, which is why Part2
// re-applies the gate activation to the update gate.
#define INTRI8_FLOAT(isa)                                                  \
  template <>                                                              \
  GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl(                          \
      const std::string& act_gate, const std::string& act_state, int d)    \
      : GRUKernel<float>() {                                               \
    avx_act_gate_ = GetAVXAct<isa>(act_gate);                              \
    avx_act_state_ = GetAVXAct<isa>(act_state);                            \
  }                                                                        \
  template <>                                                              \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
      const {                                                              \
    /* W: {W_update, W_reset; W_state} */                                  \
    __m256 upd = _mm256_loadu_ps(gates);                                   \
    __m256 cand = _mm256_loadu_ps(gates + 16);                             \
    /* ht = act_gate(update) * act_state(state) */                         \
    cand = _mm256_mul_ps(avx_act_gate_->Compute(upd),                      \
                         avx_act_state_->Compute(cand));                   \
    _mm256_storeu_ps(ht, cand);                                            \
  }                                                                        \
  template <>                                                              \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1(                    \
      float* gates, const float* ht_1, float* ht) const {                  \
    /* differs from the generic version: only the reset gate is */         \
    /* activated here and gates is left unmodified */                      \
    __m256 reset = _mm256_loadu_ps(gates + 8);                             \
    __m256 prev = _mm256_loadu_ps(ht_1);                                   \
    /* ht = act_gate(reset) * ht_1 */                                      \
    reset = _mm256_mul_ps(avx_act_gate_->Compute(reset), prev);            \
    _mm256_storeu_ps(ht, reset);                                           \
  }                                                                        \
  template <>                                                              \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2(                    \
      float* gates, const float* ht_1, float* ht) const {                  \
    /* differs from the generic version: the update gate is activated */   \
    /* here because Part1 did not store it back */                         \
    __m256 upd = _mm256_loadu_ps(gates);                                   \
    __m256 cand = _mm256_loadu_ps(gates + 16);                             \
    __m256 prev = _mm256_loadu_ps(ht_1);                                   \
    upd = avx_act_gate_->Compute(upd);                                     \
    /* cand = zt * ht~ */                                                  \
    cand = _mm256_mul_ps(upd, avx_act_state_->Compute(cand));              \
    /* ht = zt * ht~ + (1 - zt) * ht_1 */                                  \
    upd = _mm256_add_ps(                                                   \
        cand, _mm256_mul_ps(_mm256_sub_ps(_mm256_set1_ps(1.f), upd), prev)); \
    _mm256_storeu_ps(ht, upd);                                             \
  }

// Emit the specializations for every ISA enabled at compile time.
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
// Signature of the KernelPool::Get specialization that hands out GRU kernels
// for a given (gate activation, state activation, hidden size) triple.
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype)                       \
  template <>                                                             \
  std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get<            \
      GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
      const std::string& act_gate, const std::string& act_state, int d)

// Cache key for one GRU kernel instance.
// The two activation names are separated by '|' so that the pair stays
// unambiguous when one of them is empty: the empty string is accepted as an
// alias for "identity", and without a delimiter ("", "tanh") and ("tanh", "")
// would map to the same key and share a kernel with swapped activations.
// NOTE(review): assumes the key is only produced and consumed inside
// REGISTER_JITKERNEL_ARGS -- confirm nothing else reconstructs it.
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
  #ker_key #dtype_key + std::to_string(d) + act_gate + "|" + act_state

// Creates the concrete implementation for the selected dtype / ISA / block
// and stores it in `p` (declared by the registration macro).
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \
  p = std::dynamic_pointer_cast<ker<dtype>>(       \
      std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));

REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
                        JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);

// Helper macros are local to this translation unit section.
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): ...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self.D = 8 self.D = 8
class TestFusionGRUOpMD3(TestFusionGRUOp):
    """Fusion GRU test variant with M = 17 and D = 15.

    Neither size is a multiple of 8; presumably this exercises the generic
    (non 8-wide) kernel path -- TODO confirm against the base test.
    """

    def set_confs(self):
        # Override the sizes used by the inherited test.
        self.M, self.D = 17, 15
class TestFusionGRUOpBS1(TestFusionGRUOp): class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self): def set_confs(self):
self.lod = [[3]] self.lod = [[3]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册