diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index 55e9cc4a98bd6d36ce5d6bb4116039d0ec18b485..1b18368e0e16365682520b62a7f6adab0cbb527f 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -24,8 +24,17 @@ namespace paddle { namespace operators { -using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +inline void ReorderInitState(const platform::DeviceContext& ctx, + const framework::Tensor& src, const size_t* index, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims(), ctx.GetPlace()); + row_shuffle(ctx, src, index, *dst, indexed_src); +} template class GRUKernel : public framework::OpKernel { @@ -33,7 +42,6 @@ class GRUKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& context) const { auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); - const T* h0_data = h0 ? h0->data() : nullptr; auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* bias = context.Input("Bias"); @@ -66,7 +74,18 @@ class GRUKernel : public framework::OpKernel { gru_value.gateWeight = const_cast(weight_data); gru_value.stateWeight = const_cast(weight_data + 2 * frame_size * frame_size); - gru_value.prevOutValue = const_cast(h0_data); + Tensor ordered_h0; + const size_t* order = batch_gate->lod()[2].data(); + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(context.device_context(), *h0, order, + &ordered_h0, true); + gru_value.prevOutValue = ordered_h0.data(); + } else { + gru_value.prevOutValue = nullptr; + } auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; for (size_t n = 0; n < num_batch; n++) { @@ -102,7 +121,6 @@ class GRUGradKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { auto* h0 = context.Input("H0"); - const T* h0_data = h0 ? h0->data() : nullptr; auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* batch_gate = context.Input("BatchGate"); @@ -135,6 +153,17 @@ class GRUGradKernel : public framework::OpKernel { zero(dev_ctx, &batch_gate_grad, static_cast(0.0)); zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0.0)); + Tensor ordered_h0, ordered_h0_grad; + const size_t* order = batch_gate->lod()[2].data(); + if (h0) { + ReorderInitState(context.device_context(), *h0, order, + &ordered_h0, true); + } + if (h0_grad) { + ordered_h0_grad.mutable_data(h0_grad->dims(), context.GetPlace()); + zero(context.device_context(), &ordered_h0_grad, static_cast(0.0)); + } + bool is_reverse = context.Attr("is_reverse"); batch_hidden_grad.set_lod(batch_hidden->lod()); to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); @@ -176,14 +205,9 @@ class GRUGradKernel : public framework::OpKernel { batch_reset_hidden_prev_grad.Slice(bstart, bend); gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data(); if (n == 0) { - gru_value.prevOutValue = const_cast(h0_data); - if (h0_grad) { - T* h0_grad_data = h0_grad->mutable_data(context.GetPlace()); - zero(dev_ctx, h0_grad, static_cast(0.0)); - gru_grad.prevOutGrad = h0_grad_data; - } else { - gru_grad.prevOutGrad = nullptr; - } + gru_value.prevOutValue = h0 ? ordered_h0.data() : nullptr; + gru_grad.prevOutGrad = + h0 && h0_grad ? ordered_h0_grad.data() : nullptr; } else { int bstart_pre = static_cast(batch_starts[n - 1]); Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); @@ -208,6 +232,10 @@ class GRUGradKernel : public framework::OpKernel { math::ColwiseSum col_sum; col_sum(dev_ctx, batch_gate_grad, bias_grad); } + if (h0 && h0_grad) { + ReorderInitState(context.device_context(), ordered_h0_grad, + order, h0_grad, false); + } } void Compute(const framework::ExecutionContext& context) const override { diff --git a/python/paddle/v2/fluid/tests/test_gru_op.py b/python/paddle/v2/fluid/tests/test_gru_op.py index b2474cff94c6c71cc62bc8e69a5d83e38d51c511..fa2c5a53ec4a01b6545e25f773c11277a4d24706 100644 --- a/python/paddle/v2/fluid/tests/test_gru_op.py +++ b/python/paddle/v2/fluid/tests/test_gru_op.py @@ -6,7 +6,8 @@ from test_lstm_op import identity, sigmoid, tanh, relu class TestGRUOp(OpTest): - batch_size = 9 + lod = [[0, 2, 6, 9]] + batch_size = lod[0][-1] frame_size = 5 activate = { 'identity': identity, @@ -35,7 +36,7 @@ class TestGRUOp(OpTest): seq_starts[sorted_seqs[i]] + batch_idx) idx_in_seq.append(idx) idx_in_seq_list.append(idx_in_seq) - return idx_in_seq_list + return idx_in_seq_list, sorted_seqs def gru_step(self, x, h_p, w, b): batch_size = x.shape[0] @@ -66,8 +67,8 @@ class TestGRUOp(OpTest): batch_hidden = self.outputs['BatchHidden'] hidden = self.outputs['Hidden'] idx_in_seq_list = self.idx_in_seq_list - h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros( - (len(idx_in_seq_list[0]), self.frame_size)) + h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key( + 'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size)) num_batch = len(idx_in_seq_list) end_idx = 0 for batch_idx in range(num_batch): @@ -84,8 +85,9 @@ class TestGRUOp(OpTest): return batch_gate, batch_reset_hidden_prev, hidden def set_data(self): - lod = [[0, 2, 6, self.batch_size]] - self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse) + lod = self.lod + self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch( + lod, self.is_reverse) batch_size = self.batch_size frame_size = self.frame_size input = np.random.rand(batch_size, frame_size * 3).astype('float64') @@ -146,7 +148,7 @@ class TestGRUOpReverse(TestGRUOp): def set_confs(self): self.is_reverse = True self.attrs = { - 'activation': 'identity', + 'activation': 'tanh', 'gate_activation': 'sigmoid', 'is_reverse': self.is_reverse }