Commit 4b28fab8 authored by tensor-tang

enable more acts

Parent 607c4195
...@@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
     auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
+    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
+    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
+    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
+    if (platform::jit::MayIUse(platform::jit::avx)) {
+      math::VecActivations<T, platform::jit::avx> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    } else {
+      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    }
+
     auto x_lod = x->lod();
     auto x_dims = x->dims();    // T x M
     auto wh_dims = wh->dims();  // D x 4D
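This hunk replaces the hardcoded math::vec_sigmoid<T> / math::vec_tanh<T> calls with functors looked up from the op attributes gate_activation, cell_activation and candidate_activation, letting math::VecActivations pick an AVX-specialized kernel when platform::jit::MayIUse(platform::jit::avx) reports support. As a rough, self-contained illustration of the name-to-functor dispatch, a scalar sketch could look like the following (MakeActivation is a hypothetical stand-in with plain loops, not Paddle's actual VecActivations):

#include <cmath>
#include <functional>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for math::VecActivations: maps an activation name to
// a function that applies it element-wise over n values (scalar loops only).
template <typename T>
std::function<void(const int, const T*, T*)> MakeActivation(
    const std::string& type) {
  if (type == "sigmoid") {
    return [](const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) {
        y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x[i]));
      }
    };
  } else if (type == "tanh") {
    return [](const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) y[i] = std::tanh(x[i]);
    };
  } else if (type == "relu") {
    return [](const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) y[i] = x[i] > T(0) ? x[i] : T(0);
    };
  } else if (type == "identity") {
    return [](const int n, const T* x, T* y) {
      for (int i = 0; i < n; ++i) y[i] = x[i];
    };
  }
  throw std::invalid_argument("unsupported activation: " + type);
}

A caller would then write, for example, auto act_gate = MakeActivation<float>("sigmoid"); and apply it as act_gate(D3, xx_data + D, xx_data + D);, mirroring the kernel code above.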
...@@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
        prev_cell_data = c0_data + i * D;
      } else {
        // W_ch, W_ih, W_fh, W_oh
-        // actgate
-        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
-        // ch gate
-        math::vec_tanh<T>(D, xx_data, xx_data);
+        act_gate(D3, xx_data + D, xx_data + D);
+        act_cand(D, xx_data, xx_data);
        // cell out= input*tilde
        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
        // hidden out= act_state(cellout) * outgate
-        // act state
-        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        act_cell(D, cell_out_data, xx_data + D2);
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
        // prev
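With the xx_data buffer laid out as D-wide blocks [\tilde{c}, i, f, o] (per the W_ch, W_ih, W_fh, W_oh comment), the sequence-start branch above, which has no previous cell state, appears to compute

    c_t = act_cand(\tilde{c}_t) \odot act_gate(i_t), \qquad
    h_t = act_gate(o_t) \odot act_cell(c_t)

where \odot is the element-wise product carried out by blas.VMUL.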
...@@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
                  D4);
        // W_ch, W_ih, W_fh, W_oh
-        // actgate
-        math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
-        // ch gate
-        math::vec_tanh<T>(D, xx_data, xx_data);
+        act_gate(D3, xx_data + D, xx_data + D);
+        act_cand(D, xx_data, xx_data);
        // a = forget * prev_cell
        blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
...@@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
        // hidden out= act_state(cellout) * outgate
-        // act state
-        math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
+        act_cell(D, cell_out_data, xx_data + D2);
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
        // prev
...
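For the remaining steps, where a previous cell state exists, the two hunks above (with the input-times-candidate product elided between them in this excerpt) appear to implement the usual LSTM update

    c_t = act_gate(f_t) \odot c_{t-1} + act_gate(i_t) \odot act_cand(\tilde{c}_t), \qquad
    h_t = act_gate(o_t) \odot act_cell(c_t)

with the sum accumulated by blas.VADD into cell_out_data, so only the activation calls change while the BLAS element-wise operations stay as they were.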
...@@ -45,7 +45,7 @@ def fusion_lstm(
 class TestLstmOp(OpTest):
     def set_argument(self):
-        self.lod = [[2, 3, 2]]
+        pass

     def setUp(self):
         self.op_type = 'fusion_lstm'
...