diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 4e534789ce6aed6d5be94ec36deeedec21e69cd9..a9450337e74fe89e16524ff03346bebe43bb2bd3 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -98,6 +98,23 @@ class GRUKernel : public framework::OpKernel { auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); auto blas = math::GetBlas(dev_ctx); + + // TODO(TJ): make a class, make one pack + T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size * 2 /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_gate); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, + frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, + packed_gate); + T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, + frame_size /*width of weight*/, + frame_size /*height of height*/); + PADDLE_ENFORCE(packed_state); + blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, + frame_size, T(1.0), gru_value.state_weight, frame_size, + packed_state); + for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -110,9 +127,10 @@ class GRUKernel : public framework::OpKernel { gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); if (gru_value.prev_out_value) { - blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1, - gru_value.prev_out_value, frame_size, gru_value.gate_weight, - frame_size * 2, 1, gru_value.gate_value, frame_size * 3); + blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size, + frame_size * 2, frame_size, gru_value.prev_out_value, + frame_size, packed_gate, frame_size * 2, T(1), + gru_value.gate_value, frame_size * 3); } math::detail::forward_reset_output( @@ -120,10 +138,10 @@ class GRUKernel : public framework::OpKernel { cur_batch_size, active_gate); if (gru_value.prev_out_value) { - blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1, - gru_value.reset_output_value, frame_size, - gru_value.state_weight, frame_size, 1, - gru_value.gate_value + frame_size * 2, frame_size * 3); + blas.GEMM_COMPUTE( + CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size, + gru_value.reset_output_value, frame_size, packed_state, frame_size, + T(1), gru_value.gate_value + frame_size * 2, frame_size * 3); } math::detail::forward_final_output( @@ -132,6 +150,8 @@ class GRUKernel : public framework::OpKernel { gru_value.prev_out_value = gru_value.output_value; } + blas.GEMM_FREE(packed_gate); + blas.GEMM_FREE(packed_state); math::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 2470df9d78339eec7d000f9c6c5c67704ace7a33..485e96227e4fe9bc9380b83e13aab1ac044dea89 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -165,7 +165,7 @@ class BlasT : private Blas { template T* GEMM_ALLOC(ARGS... args) const { - Base()->template GEMM_ALLOC(args...); + return Base()->template GEMM_ALLOC(args...); } template