提交 aa83e19e 编写于 作者: G guosheng

Remove lstm_op including in gru_op

上级 afd1f361
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#pragma once #pragma once
#include "paddle/operators/lstm_op.h"
#include "paddle/operators/math/gru_compute.h" #include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h" #include "paddle/operators/math/sequence2batch.h"
...@@ -25,6 +24,18 @@ ...@@ -25,6 +24,18 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
const framework::Tensor& src, const size_t* index,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index, *dst, indexed_src);
}
template <typename Place, typename T> template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> { class GRUKernel : public framework::OpKernel<T> {
public: public:
...@@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_reset_hidden_prev_grad.Slice(bstart, bend); batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>(); gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) { if (n == 0) {
if (h0) { gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
gru_value.prevOutValue = ordered_h0.data<T>(); gru_grad.prevOutGrad =
} else { h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
gru_value.prevOutValue = nullptr;
}
if (h0 && h0_grad) {
gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
} else {
gru_grad.prevOutGrad = nullptr;
}
} else { } else {
int bstart_pre = static_cast<int>(batch_starts[n - 1]); int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册