diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index fcd551ed3b00cc7c3fbbc2c6e7f9fd071d3228a4..d1a0a05c7092235f67bf9e096684405d29a59168 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -266,25 +266,24 @@ class FusionGRUKernel : public framework::OpKernel { batched_input_data, D3); T* cur_batched_data = batched_input_data; + T* cur_out_data = batched_out_data; T* cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { act_gate(D2, cur_batched_data, cur_batched_data); // rt = rt*ht_1 inplace result - // TODO(TJ): try to save to cur out data - // maybe get benifits avoiding cache miss in next gemm - blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, - cur_batched_data + D); + blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; + cur_out_data += D; } cur_batched_data = batched_input_data; + cur_out_data = batched_out_data; blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast(1), - cur_batched_data + D, D3, wh_state_data, D, static_cast(1), + cur_out_data, D, wh_state_data, D, static_cast(1), cur_batched_data + D2, D3); - T* cur_out_data = batched_out_data; cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { // ht~ = act_state(...)