Commit aa83e19e authored by guosheng

Remove lstm_op include from gru_op

Parent afd1f361
@@ -14,7 +14,6 @@
 #pragma once
-#include "paddle/operators/lstm_op.h"
 #include "paddle/operators/math/gru_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -25,6 +24,18 @@
 namespace paddle {
 namespace operators {
 
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+inline void ReorderInitState(const platform::DeviceContext& ctx,
+                             const framework::Tensor& src, const size_t* index,
+                             framework::Tensor* dst, bool indexed_src) {
+  math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
+  row_shuffle(ctx, src, index, *dst, indexed_src);
+}
+
 template <typename Place, typename T>
 class GRUKernel : public framework::OpKernel<T> {
  public:
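
Note on the hunk above: `ReorderInitState` was previously picked up transitively through `lstm_op.h`; defining it locally lets that include go. It permutes the rows of the initial hidden state into the batch order produced by sequence2batch, delegating the copy to `math::CopyMatrixRowsFunctor`. Below is a minimal, Paddle-free C++ sketch of that row shuffle; the gather/scatter reading of `indexed_src` is an assumption for illustration, not code from this commit.

#include <cstddef>
#include <iostream>
#include <vector>

// Sketch of the row shuffle behind math::CopyMatrixRowsFunctor (assumed
// semantics): with indexed_src == true, row i of dst is row index[i] of
// src (a gather); with indexed_src == false, row index[i] of dst is row
// i of src (a scatter).
void ReorderRows(const std::vector<std::vector<float>>& src,
                 const size_t* index, std::vector<std::vector<float>>* dst,
                 bool indexed_src) {
  dst->resize(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    if (indexed_src) {
      (*dst)[i] = src[index[i]];
    } else {
      (*dst)[index[i]] = src[i];
    }
  }
}

int main() {
  // Three sequences' initial states, reordered to batch order {2, 0, 1}.
  std::vector<std::vector<float>> h0 = {{0.f}, {1.f}, {2.f}};
  const size_t order[] = {2, 0, 1};
  std::vector<std::vector<float>> ordered_h0;
  ReorderRows(h0, order, &ordered_h0, true);
  for (const auto& row : ordered_h0) std::cout << row[0] << ' ';  // 2 0 1
  std::cout << '\n';
  return 0;
}
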
@@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
           batch_reset_hidden_prev_grad.Slice(bstart, bend);
       gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
       if (n == 0) {
-        if (h0) {
-          gru_value.prevOutValue = ordered_h0.data<T>();
-        } else {
-          gru_value.prevOutValue = nullptr;
-        }
-        if (h0 && h0_grad) {
-          gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
-        } else {
-          gru_grad.prevOutGrad = nullptr;
-        }
+        gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
+        gru_grad.prevOutGrad =
+            h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
       } else {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
......
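
For context on the last hunk: the replaced if/else blocks and the new ternaries are behaviorally identical. The code sits inside the backward pass's per-batch loop (truncated above): at the first step (`n == 0`) the previous output is the reordered initial state, or `nullptr` when no `h0` input was given, while every later step reads a slice of the previous batch's hidden output. A simplified, standalone C++ schematic of that chaining follows; the names are assumptions, and the real gradient kernel walks the batched steps in reverse, which is elided here.

#include <cstddef>
#include <vector>

// Schematic of the prevOut chaining shown in the hunk: step 0 falls back
// to the (possibly absent) initial state, later steps read the prior
// step's output.
struct GRUStep {
  const float* prevOutValue;  // mirrors gru_value.prevOutValue
};

void Walk(const std::vector<std::vector<float>>& batch_hidden,
          const float* ordered_h0 /* may be nullptr */) {
  for (std::size_t n = 0; n < batch_hidden.size(); ++n) {
    GRUStep step;
    step.prevOutValue =
        (n == 0) ? ordered_h0  // nullptr when no initial state was given
                 : batch_hidden[n - 1].data();
    // ... run one GRU step using step.prevOutValue ...
    (void)step;
  }
}
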