From 038c16eed296cabab148fc067b0c4f383dfddc76 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 30 Aug 2018 22:59:59 +0800 Subject: [PATCH] save intermediate data to out buffer --- paddle/fluid/operators/fusion_gru_op.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index fcd551ed3b0..d1a0a05c709 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -266,25 +266,24 @@ class FusionGRUKernel : public framework::OpKernel<T> { batched_input_data, D3); T* cur_batched_data = batched_input_data; + T* cur_out_data = batched_out_data; T* cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { act_gate(D2, cur_batched_data, cur_batched_data); // rt = rt*ht_1 inplace result - // TODO(TJ): try to save to cur out data - // maybe get benifits avoiding cache miss in next gemm - blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, - cur_batched_data + D); + blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; + cur_out_data += D; } cur_batched_data = batched_input_data; + cur_out_data = batched_out_data; blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1), - cur_batched_data + D, D3, wh_state_data, D, static_cast<T>(1), + cur_out_data, D, wh_state_data, D, static_cast<T>(1), cur_batched_data + D2, D3); - T* cur_out_data = batched_out_data; cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { // ht~ = act_state(...) -- GitLab