提交 b103072d 编写于 作者: G guosheng

Fix data order of H0 in GRU Operator

上级 80de144b
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/operators/lstm_op.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
......@@ -24,20 +25,12 @@
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
......@@ -74,7 +67,18 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
gru_value.prevOutValue = const_cast<T*>(h0_data);
Tensor ordered_h0;
const size_t* order = batch_gate->lod()[2].data();
if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true);
gru_value.prevOutValue = ordered_h0.data<T>();
} else {
gru_value.prevOutValue = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; n++) {
......@@ -110,7 +114,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* batch_gate = context.Input<LoDTensor>("BatchGate");
......@@ -143,6 +146,16 @@ class GRUGradKernel : public framework::OpKernel<T> {
zero(context.device_context(), &batch_reset_hidden_prev_grad,
static_cast<T>(0.0));
Tensor ordered_h0, ordered_h0_grad;
const size_t* order = batch_gate->lod()[2].data();
if (h0) {
ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true);
}
if (h0_grad) {
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
}
bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
......@@ -185,11 +198,13 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) {
gru_value.prevOutValue = const_cast<T*>(h0_data);
if (h0_grad) {
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), h0_grad, static_cast<T>(0.0));
gru_grad.prevOutGrad = h0_grad_data;
if (h0) {
gru_value.prevOutValue = ordered_h0.data<T>();
} else {
gru_value.prevOutValue = nullptr;
}
if (h0 && h0_grad) {
gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
} else {
gru_grad.prevOutGrad = nullptr;
}
......@@ -220,6 +235,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
auto place = context.GetEigenDevice<Place>();
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
}
if (h0 && h0_grad) {
ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
order, h0_grad, false);
}
}
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -6,7 +6,8 @@ from test_lstm_op import identity, sigmoid, tanh, relu
class TestGRUOp(OpTest):
batch_size = 9
lod = [[0, 2, 6, 9]]
batch_size = lod[0][-1]
frame_size = 5
activate = {
'identity': identity,
......@@ -35,7 +36,7 @@ class TestGRUOp(OpTest):
seq_starts[sorted_seqs[i]] + batch_idx)
idx_in_seq.append(idx)
idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list
return idx_in_seq_list, sorted_seqs
def gru_step(self, x, h_p, w, b):
batch_size = x.shape[0]
......@@ -66,8 +67,8 @@ class TestGRUOp(OpTest):
batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden']
idx_in_seq_list = self.idx_in_seq_list
h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
(len(idx_in_seq_list[0]), self.frame_size))
h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key(
'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size))
num_batch = len(idx_in_seq_list)
end_idx = 0
for batch_idx in range(num_batch):
......@@ -84,8 +85,9 @@ class TestGRUOp(OpTest):
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = [[0, 2, 6, self.batch_size]]
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
lod = self.lod
self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
lod, self.is_reverse)
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
......@@ -146,8 +148,8 @@ class TestGRUOpReverse(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.attrs = {
'activation': 'identity',
'gate_activation': 'sigmoid',
'activation': 'tanh',
'gate_activation': 'tanh',
'is_reverse': self.is_reverse
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册