提交 ba943d38 编写于 作者: T tensor-tang

make runtime avx act

上级 3462c299
......@@ -299,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
fc_out->Resize({max_seq_len, 1});
math::VecActivations<T> act_functor;
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
}
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册