diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index ba90ec9816c40a6a49065ac6efcee6b93dffce90..b2cf358994f1bdf3597eefe67bee4e77599a9b6b 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/operators/lstm_op.h" #include "paddle/operators/math/gru_compute.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/sequence2batch.h" @@ -24,20 +25,12 @@ namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -template -using EigenMatrix = framework::EigenMatrix; - template class GRUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); - const T* h0_data = h0 ? h0->data() : nullptr; auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* bias = context.Input("Bias"); @@ -74,7 +67,18 @@ class GRUKernel : public framework::OpKernel { gru_value.gateWeight = const_cast(weight_data); gru_value.stateWeight = const_cast(weight_data + 2 * frame_size * frame_size); - gru_value.prevOutValue = const_cast(h0_data); + Tensor ordered_h0; + const size_t* order = batch_gate->lod()[2].data(); + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(context.device_context(), *h0, order, + &ordered_h0, true); + gru_value.prevOutValue = ordered_h0.data(); + } else { + gru_value.prevOutValue = nullptr; + } auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; for (size_t n = 0; n < num_batch; n++) { @@ -110,7 +114,6 @@ class GRUGradKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { auto* h0 = context.Input("H0"); - const T* h0_data = h0 ? h0->data() : nullptr; auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* batch_gate = context.Input("BatchGate"); @@ -143,6 +146,16 @@ class GRUGradKernel : public framework::OpKernel { zero(context.device_context(), &batch_reset_hidden_prev_grad, static_cast(0.0)); + Tensor ordered_h0, ordered_h0_grad; + const size_t* order = batch_gate->lod()[2].data(); + if (h0) { + ReorderInitState(context.device_context(), *h0, order, + &ordered_h0, true); + } + if (h0_grad) { + ordered_h0_grad.mutable_data(h0_grad->dims(), context.GetPlace()); + } + bool is_reverse = context.Attr("is_reverse"); batch_hidden_grad.set_lod(batch_hidden->lod()); to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false, @@ -185,11 +198,13 @@ class GRUGradKernel : public framework::OpKernel { batch_reset_hidden_prev_grad.Slice(bstart, bend); gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data(); if (n == 0) { - gru_value.prevOutValue = const_cast(h0_data); - if (h0_grad) { - T* h0_grad_data = h0_grad->mutable_data(context.GetPlace()); - zero(context.device_context(), h0_grad, static_cast(0.0)); - gru_grad.prevOutGrad = h0_grad_data; + if (h0) { + gru_value.prevOutValue = ordered_h0.data(); + } else { + gru_value.prevOutValue = nullptr; + } + if (h0 && h0_grad) { + gru_grad.prevOutGrad = ordered_h0_grad.data(); } else { gru_grad.prevOutGrad = nullptr; } @@ -220,6 +235,10 @@ class GRUGradKernel : public framework::OpKernel { auto place = context.GetEigenDevice(); d_b.device(place) = d_g.sum(Eigen::array({{0}})); } + if (h0 && h0_grad) { + ReorderInitState(context.device_context(), ordered_h0_grad, + order, h0_grad, false); + } } void Compute(const framework::ExecutionContext& context) const override { diff --git a/python/paddle/v2/framework/tests/test_gru_op.py b/python/paddle/v2/framework/tests/test_gru_op.py index b2474cff94c6c71cc62bc8e69a5d83e38d51c511..2bb78d10e0e08b9916e12f733b2fe4dfc8e4bae5 100644 --- a/python/paddle/v2/framework/tests/test_gru_op.py +++ b/python/paddle/v2/framework/tests/test_gru_op.py @@ -6,7 +6,8 @@ from test_lstm_op import identity, sigmoid, tanh, relu class TestGRUOp(OpTest): - batch_size = 9 + lod = [[0, 2, 6, 9]] + batch_size = lod[0][-1] frame_size = 5 activate = { 'identity': identity, @@ -35,7 +36,7 @@ class TestGRUOp(OpTest): seq_starts[sorted_seqs[i]] + batch_idx) idx_in_seq.append(idx) idx_in_seq_list.append(idx_in_seq) - return idx_in_seq_list + return idx_in_seq_list, sorted_seqs def gru_step(self, x, h_p, w, b): batch_size = x.shape[0] @@ -66,8 +67,8 @@ class TestGRUOp(OpTest): batch_hidden = self.outputs['BatchHidden'] hidden = self.outputs['Hidden'] idx_in_seq_list = self.idx_in_seq_list - h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros( - (len(idx_in_seq_list[0]), self.frame_size)) + h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key( + 'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size)) num_batch = len(idx_in_seq_list) end_idx = 0 for batch_idx in range(num_batch): @@ -84,8 +85,9 @@ class TestGRUOp(OpTest): return batch_gate, batch_reset_hidden_prev, hidden def set_data(self): - lod = [[0, 2, 6, self.batch_size]] - self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse) + lod = self.lod + self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch( + lod, self.is_reverse) batch_size = self.batch_size frame_size = self.frame_size input = np.random.rand(batch_size, frame_size * 3).astype('float64') @@ -146,8 +148,8 @@ class TestGRUOpReverse(TestGRUOp): def set_confs(self): self.is_reverse = True self.attrs = { - 'activation': 'identity', - 'gate_activation': 'sigmoid', + 'activation': 'tanh', + 'gate_activation': 'tanh', 'is_reverse': self.is_reverse }