From aa83e19e24d2318381bd4859588f15d43336f041 Mon Sep 17 00:00:00 2001 From: guosheng Date: Fri, 17 Nov 2017 14:18:34 +0800 Subject: [PATCH] Remove lstm_op including in gru_op --- paddle/operators/gru_op.h | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index a7264507b..1b18368e0 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/operators/lstm_op.h" #include "paddle/operators/math/gru_compute.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/sequence2batch.h" @@ -25,6 +24,18 @@ namespace paddle { namespace operators { +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +inline void ReorderInitState(const platform::DeviceContext& ctx, + const framework::Tensor& src, const size_t* index, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims(), ctx.GetPlace()); + row_shuffle(ctx, src, index, *dst, indexed_src); +} + template class GRUKernel : public framework::OpKernel { public: @@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel { batch_reset_hidden_prev_grad.Slice(bstart, bend); gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data(); if (n == 0) { - if (h0) { - gru_value.prevOutValue = ordered_h0.data(); - } else { - gru_value.prevOutValue = nullptr; - } - if (h0 && h0_grad) { - gru_grad.prevOutGrad = ordered_h0_grad.data(); - } else { - gru_grad.prevOutGrad = nullptr; - } + gru_value.prevOutValue = h0 ? ordered_h0.data() : nullptr; + gru_grad.prevOutGrad = + h0 && h0_grad ? ordered_h0_grad.data() : nullptr; } else { int bstart_pre = static_cast(batch_starts[n - 1]); Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); -- GitLab