提交 1997a134 编写于 作者: E eclipsess

Add Sgemm_omp path to the CPU GRU unit functor (used when _OPENMP is defined)

上级 6c21a85b
...@@ -30,20 +30,34 @@ struct GRUUnitFunctor<CPU, T> { ...@@ -30,20 +30,34 @@ struct GRUUnitFunctor<CPU, T> {
const ActivationType active_gate) { const ActivationType active_gate) {
Gemm gemm; Gemm gemm;
if (value.prev_out_value) { if (value.prev_out_value) {
#ifdef _OPENMP
gemm.Sgemm_omp(batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight,
frame_size * 2, 1, value.gate_value, frame_size * 3, false,
nullptr);
#else
gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1, gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1,
value.prev_out_value, frame_size, value.gate_weight, value.prev_out_value, frame_size, value.gate_weight,
frame_size * 2, 1, value.gate_value, frame_size * 3, false, frame_size * 2, 1, value.gate_value, frame_size * 3, false,
nullptr); nullptr);
#endif
} }
forward_reset_output(forward::gru_resetOutput<T>(), value, frame_size, forward_reset_output(forward::gru_resetOutput<T>(), value, frame_size,
batch_size, active_gate); batch_size, active_gate);
if (value.prev_out_value) { if (value.prev_out_value) {
#ifdef _OPENMP
gemm.Sgemm_omp(batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight,
frame_size, 1, value.gate_value + frame_size * 2,
frame_size * 3, false, nullptr);
#else
gemm.Sgemm(batch_size, frame_size, frame_size, 1, gemm.Sgemm(batch_size, frame_size, frame_size, 1,
value.reset_output_value, frame_size, value.state_weight, value.reset_output_value, frame_size, value.state_weight,
frame_size, 1, value.gate_value + frame_size * 2, frame_size, 1, value.gate_value + frame_size * 2,
frame_size * 3, false, nullptr); frame_size * 3, false, nullptr);
#endif
} }
forward_final_output(forward::gru_finalOutput<T>(), value, frame_size, forward_final_output(forward::gru_finalOutput<T>(), value, frame_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册