提交 640e789d 编写于 作者: T tensor-tang

add fusion gru jit kernel

上级 42aa1d40
......@@ -16,10 +16,9 @@ limitations under the License. */
#include <cstring> // for memcpy
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
......@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
}
}
#define INIT_VEC_FUNC \
std::function<void(const int, const T *, T *)> act_gate, act_state; \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::isa_any>; \
}
#define INIT_BASE_INPUT_OUTPUT \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
auto x_dims = x->dims(); /* T x M*/ \
auto wh_dims = wh->dims(); /* D x 3D*/ \
const int total_T = x_dims[0]; \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D3 = wh_dims[1]; \
const int D2 = D * 2;
#define INIT_BASE_DEFINES \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto x_lod = x->lod(); \
auto x_dims = x->dims(); /* T x M*/ \
auto wh_dims = wh->dims(); /* D x 3D*/ \
const int total_T = x_dims[0]; \
const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const auto& ker = math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \
const std::string&, const std::string&>( \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation"), D); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_VEC_FUNC
auto x_lod = x->lod();
INIT_BASE_DEFINES;
INIT_OTHER_DEFINES;
const int N = x_lod[0].size() - 1;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* wh_state_data = wh_data + D * D2;
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(place);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data,
......@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) {
prev_hidden_data = h0_data + bid * D;
} else {
// W: {W_update, W_reset; W_state}
// update gate
act_gate(D, xx_data, xx_data);
// state gate
act_state(D, xx_data + D2, xx_data + D2);
// out = a*b
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
// save prev
ker->ComputeH1(xx_data, hidden_out_data);
prev_hidden_data = hidden_out_data;
tstart = 1;
move_step();
......@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3);
act_gate(D2, xx_data, xx_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
// gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3);
act_state(D, xx_data + D2, xx_data + D2);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data;
move_step();
......@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
INIT_BASE_DEFINES;
if (x_lod[0].size() == 2) {
xx->Resize({total_T, D3});
SeqCompute(ctx);
return;
}
INIT_VEC_FUNC
INIT_OTHER_DEFINES;
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
T* batched_input_data = batched_input->mutable_data<T>(place);
T* batched_out_data = batched_out->mutable_data<T>(place);
hidden_out->mutable_data<T>(place);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
......@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* prev_hidden_data = nullptr;
if (h0) {
// reorder h0
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace());
T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
const T* h0_data = h0->data<T>();
prev_hidden_data = reordered_h0_data;
size_t sz = sizeof(T) * D;
......@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) {
// update gate
act_gate(D, cur_in_data, cur_in_data);
// state gate
act_state(D, cur_in_data + D2, cur_in_data + D2);
// out = a*b
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
ker->ComputeH1(cur_in_data, cur_out_data);
// add offset
cur_in_data += D3;
cur_out_data += D;
......@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_batched_data, cur_batched_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
cur_out_data);
cur_batched_data += D3;
cur_prev_hidden_data += D;
cur_out_data += D;
......@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) {
// ht~ = act_state(...)
act_state(D, cur_batched_data + D2, cur_batched_data + D2);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
cur_out_data);
ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
cur_out_data);
cur_batched_data += D3;
cur_prev_hidden_data += D;
cur_out_data += D;
......@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_out, hidden_out);
}
#undef INIT_VEC_FUNC
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
};
} // namespace operators
......
......@@ -75,6 +75,6 @@ endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
......@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
const T *wp_data = nullptr) const = 0;
};
template <typename T>
class GRUKernel : public Kernel {
public:
// compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0;
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0;
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
};
} // namespace jitkernel
} // namespace math
} // namespace operators
......
......@@ -354,6 +354,67 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class GRUKernelImpl : public GRUKernel<T> {
public:
explicit GRUKernelImpl(const std::string& act_gate,
const std::string& act_state, int d)
: GRUKernel<T>() {
d_ = d;
d2_ = d * 2;
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_state_d_ = GetActKernel<T>(act_state, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates);
act_state_d_->Compute(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht);
}
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates);
vmul_d_->Compute(ht_1, gates + d_, ht);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
private:
int d_, d2_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
};
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
const std::string& act_gate, const std::string& act_state, int d)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
} // namespace jitkernel
} // namespace math
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册