From d9cc6b18662295383f925e12b6a5e0cf5dabd14a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 3 Aug 2018 13:31:53 +0800 Subject: [PATCH] replace gru compute with details --- paddle/fluid/operators/gru_op.h | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 3b0d93e54b..4e534789ce 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -16,7 +16,10 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" +#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" +#include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence2batch.h" @@ -94,6 +97,7 @@ class GRUKernel : public framework::OpKernel { context.Attr("activation")); auto active_gate = math::detail::GetActivationType( context.Attr("gate_activation")); + auto blas = math::GetBlas(dev_ctx); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -105,9 +109,27 @@ class GRUKernel : public framework::OpKernel { gru_value.output_value = hidden_t.data(); gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); - math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + if (gru_value.prev_out_value) { + blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1, + gru_value.prev_out_value, frame_size, gru_value.gate_weight, + frame_size * 2, 1, gru_value.gate_value, frame_size * 3); + } + + math::detail::forward_reset_output( + math::detail::forward::gru_resetOutput(), gru_value, frame_size, + cur_batch_size, active_gate); + + if (gru_value.prev_out_value) { + blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1, + gru_value.reset_output_value, frame_size, + gru_value.state_weight, frame_size, 1, + gru_value.gate_value + frame_size * 2, frame_size * 3); + } + + math::detail::forward_final_output( + math::detail::forward::gru_finalOutput(), gru_value, frame_size, + cur_batch_size, active_node); + gru_value.prev_out_value = gru_value.output_value; } -- GitLab