提交 4b28fab8 编写于 作者: tensor-tang

enable more activations (gate, cell and candidate activations are now selected via op attributes)

上级 607c4195
......@@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
}
auto x_lod = x->lod();
auto x_dims = x->dims(); // T x M
auto wh_dims = wh->dims(); // D x 4D
......@@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_cell_data = c0_data + i * D;
} else {
// W_ch, W_ih, W_fh, W_oh
// actgate
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
// ch gate
math::vec_tanh<T>(D, xx_data, xx_data);
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data);
// cell out= input*tilde
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
// hidden out= act_state(cellout) * outgate
// act state
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
act_cell(D, cell_out_data, xx_data + D2);
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
......@@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
D4);
// W_ch, W_ih, W_fh, W_oh
// actgate
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
// ch gate
math::vec_tanh<T>(D, xx_data, xx_data);
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data);
// a = forget * prev_cell
blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
......@@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
// hidden out= act_state(cellout) * outgate
// act state
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
act_cell(D, cell_out_data, xx_data + D2);
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
......
......@@ -45,7 +45,7 @@ def fusion_lstm(
class TestLstmOp(OpTest):
def set_argument(self):
self.lod = [[2, 3, 2]]
pass
def setUp(self):
self.op_type = 'fusion_lstm'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册