提交 59621390 编写于 作者: T tensor-tang

add gru seq mode forward

上级 b0d36c4c
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(gru_use_seq, true, "Use sequence mode");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -84,7 +86,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -84,7 +86,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->SetOutputDim("BatchedOut", out_dims); ctx->SetOutputDim("BatchedOut", out_dims);
ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Hidden");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; int xx_width;
if (FLAGS_gru_use_seq) {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
}
ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX"); ctx->ShareLoD("X", "XX");
} }
...@@ -157,6 +164,122 @@ template <typename T> ...@@ -157,6 +164,122 @@ template <typename T>
class FusionGRUKernel : public framework::OpKernel<T> { class FusionGRUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
if (FLAGS_gru_use_seq) {
SeqCompute(ctx);
} else {
BatchCompute(ctx);
}
}
#define INIT_VEC_FUNC \
std::function<void(const int, const T *, T *)> act_gate, act_state; \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::isa_any>; \
}
void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* h0 = ctx.Input<Tensor>("H0");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
bool is_reverse = ctx.Attr<bool>("is_reverse");
INIT_VEC_FUNC
auto x_lod = x->lod();
auto x_dims = x->dims(); // T x M
auto wh_dims = wh->dims(); // D x 3D
const int N = x_lod[0].size() - 1;
const int total_T = x_dims[0];
const int M = x_dims[1];
const int D3 = wh_dims[1];
const int D = wh_dims[0];
const int D2 = D * 2;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* wh_state_data = wh_data + D * D2;
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data, bias ? bias->data<T>() : NULL);
int xx_offset = D3;
int gate_offset = D;
if (is_reverse) {
const int offset = (total_T - 1) * D;
xx_data = xx_data + offset * 3;
hidden_out_data = hidden_out_data + offset;
xx_offset = -D3;
gate_offset = -D;
}
auto move_step = [&]() {
xx_data = xx_data + xx_offset;
hidden_out_data = hidden_out_data + gate_offset;
};
for (int i = 0; i < N; ++i) {
int bid = is_reverse ? N - 1 - i : i;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_hidden_data = NULL;
int tstart = 0;
if (h0_data) {
prev_hidden_data = h0_data + bid * D;
} else {
// W: {W_update, W_reset; W_state}
// update gate
act_gate(D, xx_data, xx_data);
// state gate
act_state(D, xx_data + D2, xx_data + D2);
// out = a*b
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data;
tstart = 1;
move_step();
}
for (int step = tstart; step < seq_len; ++step) {
// gemm prev * (Wu + Wr)
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3);
act_gate(D2, xx_data, xx_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
// gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3);
act_state(D, xx_data + D2, xx_data + D2);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data;
move_step();
}
}
}
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX"); auto* wx = ctx.Input<Tensor>("WeightX");
...@@ -171,21 +294,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -171,21 +294,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
bool is_reverse = ctx.Attr<bool>("is_reverse"); bool is_reverse = ctx.Attr<bool>("is_reverse");
std::function<void(const int, const T *, T *)> act_gate, act_state; INIT_VEC_FUNC
std::function<void(const int, const T*, const T*, const T*, T*)> cross;
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_state_str = ctx.Attr<std::string>("activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_state = act_functor(act_state_str);
cross = math::vec_cross<T, platform::jit::avx>;
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_state = act_functor(act_state_str);
cross = math::vec_cross<T, platform::jit::isa_any>;
}
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>(); const T* wx_data = wx->data<T>();
...@@ -305,6 +414,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -305,6 +414,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out->set_lod(batched_lod); batched_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_out, hidden_out); to_seq(dev_ctx, *batched_out, hidden_out);
} }
#undef INIT_VEC_FUNC
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册