diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h
index a7264507bbd127f3f611cf6c8ef95c0c92fde333..1b18368e0e16365682520b62a7f6adab0cbb527f 100644
--- a/paddle/operators/gru_op.h
+++ b/paddle/operators/gru_op.h
@@ -14,7 +14,6 @@
 
 #pragma once
 
-#include "paddle/operators/lstm_op.h"
 #include "paddle/operators/math/gru_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -25,6 +24,18 @@
 namespace paddle {
 namespace operators {
 
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+inline void ReorderInitState(const platform::DeviceContext& ctx,
+                             const framework::Tensor& src, const size_t* index,
+                             framework::Tensor* dst, bool indexed_src) {
+  math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
+  row_shuffle(ctx, src, index, *dst, indexed_src);
+}
+
 template <typename Place, typename T>
 class GRUKernel : public framework::OpKernel<T> {
  public:
@@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
           batch_reset_hidden_prev_grad.Slice(bstart, bend);
       gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
       if (n == 0) {
-        if (h0) {
-          gru_value.prevOutValue = ordered_h0.data<T>();
-        } else {
-          gru_value.prevOutValue = nullptr;
-        }
-        if (h0 && h0_grad) {
-          gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
-        } else {
-          gru_grad.prevOutGrad = nullptr;
-        }
+        gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
+        gru_grad.prevOutGrad =
+            h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
       } else {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);