diff --git a/src/operators/math/gru_compute.cpp b/src/operators/math/gru_compute.cpp index 8ebf92059b5f5205b3169a6992039d3f050b3b4b..9e77f572c53bc2ba9be57f5edbd2b4bf85f5305e 100644 --- a/src/operators/math/gru_compute.cpp +++ b/src/operators/math/gru_compute.cpp @@ -30,20 +30,34 @@ struct GRUUnitFunctor { const ActivationType active_gate) { Gemm gemm; if (value.prev_out_value) { +#ifdef _OPENMP + gemm.Sgemm_omp(batch_size, frame_size * 2, frame_size, 1, + value.prev_out_value, frame_size, value.gate_weight, + frame_size * 2, 1, value.gate_value, frame_size * 3, false, + nullptr); +#else gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1, value.prev_out_value, frame_size, value.gate_weight, frame_size * 2, 1, value.gate_value, frame_size * 3, false, nullptr); +#endif } forward_reset_output(forward::gru_resetOutput(), value, frame_size, batch_size, active_gate); if (value.prev_out_value) { +#ifdef _OPENMP + gemm.Sgemm_omp(batch_size, frame_size, frame_size, 1, + value.reset_output_value, frame_size, value.state_weight, + frame_size, 1, value.gate_value + frame_size * 2, + frame_size * 3, false, nullptr); +#else gemm.Sgemm(batch_size, frame_size, frame_size, 1, value.reset_output_value, frame_size, value.state_weight, frame_size, 1, value.gate_value + frame_size * 2, frame_size * 3, false, nullptr); +#endif } forward_final_output(forward::gru_finalOutput(), value, frame_size,