提交 a79a77ee 编写于 作者: T tensor-tang

refine and clean code

上级 c459fb5b
...@@ -215,46 +215,53 @@ This operator fuse the X into LSTM, more details can refer to LSTM op. ...@@ -215,46 +215,53 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
template <typename T> template <typename T>
class FuisonLSTMKernel : public framework::OpKernel<T> { class FuisonLSTMKernel : public framework::OpKernel<T> {
public: public:
// NOTE: side-by-side diff rendering. On the lines below that carry two statements, the
// leading text appears to be the removed (old) SeqCompute prologue; the macro is the
// trailing text that ends in a line-continuation backslash.
//
// INIT_VEC_FUNC: statement macro, expanded where `ctx` and the template parameter `T`
// are visible. It
//   * declares three std::function<void(const int, const T*, T*)> objects:
//     act_gate, act_cell, act_cand;
//   * binds references act_gate_str / act_cell_str / act_cand_str to the string
//     attributes "gate_activation", "cell_activation" and "candidate_activation";
//   * fills the three callables from math::VecActivations<T, platform::jit::avx> when
//     platform::jit::MayIUse(platform::jit::avx) is true, otherwise from the
//     platform::jit::isa_any instantiation.
// The macro body is not wrapped in its own block, so all six names stay declared in the
// scope that expands it; only `act_functor` is local to the if/else branches.
void SeqCompute(const framework::ExecutionContext& ctx) const { #define INIT_VEC_FUNC \
using DeviceContext = paddle::platform::CPUDeviceContext; std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
auto* x = ctx.Input<LoDTensor>("X"); auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto* h0 = ctx.Input<Tensor>("H0"); auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto* c0 = ctx.Input<Tensor>("C0"); auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
auto* wx = ctx.Input<Tensor>("WeightX"); if (platform::jit::MayIUse(platform::jit::avx)) { \
auto* wh = ctx.Input<Tensor>("WeightH"); math::VecActivations<T, platform::jit::avx> act_functor; \
auto* bias = ctx.Input<Tensor>("Bias"); act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
auto* xx = ctx.Output<LoDTensor>("XX"); act_cand = act_functor(act_cand_str); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); } else { \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
}
// INIT_BASE_INPUT_OUTPUT: statement macro, expanded where an execution context named
// `ctx` is visible. It declares, in the expanding scope:
//   * input pointers  x ("X", LoDTensor), h0 ("H0"), c0 ("C0"), wx ("WeightX"),
//     wh ("WeightH"), bias ("Bias") — the last five fetched as Tensor;
//   * output pointers xx ("XX"), hidden_out ("Hidden"), cell_out ("Cell"), all LoDTensor;
//   * bool is_reverse, read from the "is_reverse" attribute.
// No null checks are made here; SeqCompute tests h0 before dereferencing it.
// NOTE: the final line shows the is_reverse declaration twice because it is an unchanged
// context line printed in both columns of this side-by-side diff; the macro itself ends
// with a single declaration and no trailing backslash.
#define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); bool is_reverse = ctx.Attr<bool>("is_reverse");
// NOTE: side-by-side diff rendering. On each line below, the leading statement appears to
// be removed (old) SeqCompute code; the macro is the trailing text.
//
// INIT_BASE_SIZES: statement macro that requires `x` and `wh` to be declared already
// (both SeqCompute and BatchCompute expand INIT_BASE_INPUT_OUTPUT first). It declares:
//   * x_dims  = x->dims()   — annotated "T x M" by the original author;
//   * wh_dims = wh->dims()  — annotated "D x 4D";
//   * const int M = x_dims[1], D = wh_dims[0], D2 = D * 2, D3 = D * 3, D4 = wh_dims[1].
// D4 is read from the weight's second dimension rather than computed as 4 * D, and the
// macro does not check that the two agree.
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; #define INIT_BASE_SIZES \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); auto x_dims = x->dims(); /* T x M*/ \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); auto wh_dims = wh->dims(); /* D x 4D*/ \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); const int M = x_dims[1]; \
if (platform::jit::MayIUse(platform::jit::avx)) { const int D = wh_dims[0]; \
math::VecActivations<T, platform::jit::avx> act_functor; const int D2 = D * 2; \
act_gate = act_functor(act_gate_str); const int D3 = D * 3; \
act_cell = act_functor(act_cell_str); const int D4 = wh_dims[1];
act_cand = act_functor(act_cand_str);
} else { void SeqCompute(const framework::ExecutionContext& ctx) const {
math::VecActivations<T, platform::jit::isa_any> act_functor; using DeviceContext = paddle::platform::CPUDeviceContext;
act_gate = act_functor(act_gate_str); INIT_BASE_INPUT_OUTPUT
act_cell = act_functor(act_cell_str); INIT_BASE_SIZES
act_cand = act_functor(act_cand_str); INIT_VEC_FUNC
}
auto x_lod = x->lod(); auto x_lod = x->lod();
auto x_dims = x->dims(); // T x M
auto wh_dims = wh->dims(); // D x 4D
const int total_T = x_dims[0]; const int total_T = x_dims[0];
const int N = x_lod[0].size() - 1; // batch size const int N = x_lod[0].size() - 1; // batch size
const int M = x_dims[1]; // x frame size
const int D = wh_dims[0];
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = wh_dims[1];
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL; const T* h0_data = h0 ? h0->data<T>() : NULL;
...@@ -343,52 +350,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -343,52 +350,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext; using DeviceContext = platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_INPUT_OUTPUT
auto* wx = ctx.Input<Tensor>("WeightX"); if (x->lod()[0].size() == 2) { // batch size == 1
auto* wh = ctx.Input<Tensor>("WeightH"); SeqCompute(ctx);
auto* bias = ctx.Input<Tensor>("Bias"); }
auto* h0 = ctx.Input<Tensor>("H0"); INIT_BASE_SIZES
auto* c0 = ctx.Input<Tensor>("C0"); INIT_VEC_FUNC
auto* xx = ctx.Output<LoDTensor>("XX");
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0"); auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput"); auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell"); auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden"); auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
bool is_reverse = ctx.Attr<bool>("is_reverse");
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
}
auto x_dims = x->dims(); // T x M
auto wh_dims = wh->dims(); // D x 4D
// auto x_lod = x->lod();
// const int N = x_lod[0].size() - 1; // batch size
// if (N == 1) {
// SeqCompute(ctx);
// }
const int M = x_dims[1];
const int D = wh_dims[0];
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = wh_dims[1];
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>(); const T* wx_data = wx->data<T>();
...@@ -485,16 +458,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -485,16 +458,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
// W_ch, W_ih, W_fh, W_oh // W_ch, W_ih, W_fh, W_oh
act_gate(D3, cur_in_data + D, cur_in_data + D); act_gate(D3, cur_in_data + D, cur_in_data + D);
act_cand(D, cur_in_data, cur_in_data); act_cand(D, cur_in_data, cur_in_data);
// a = forget * prev_cell // a = forget * prev_cell
blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2); blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2);
// b = input * tilde // b = input * tilde
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D); blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
// cell out= a+b // cell out= a+b
blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data); blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
// hidden out= act_state(cellout) * outgate // hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2); act_cell(D, cur_c_out_data, cur_in_data + D2);
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
...@@ -526,6 +495,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -526,6 +495,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
BatchCompute(ctx); BatchCompute(ctx);
} }
} }
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册