Commit b103072d, authored by guosheng

Fix data order of H0 in GRU Operator

Parent 80de144b
@@ -14,6 +14,7 @@
 #pragma once

+#include "paddle/operators/lstm_op.h"
 #include "paddle/operators/math/gru_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -24,20 +25,12 @@
 namespace paddle {
 namespace operators {

-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
-
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
-
 template <typename Place, typename T>
 class GRUKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
-    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* bias = context.Input<Tensor>("Bias");
@@ -74,7 +67,18 @@ class GRUKernel : public framework::OpKernel<T> {
     gru_value.gateWeight = const_cast<T*>(weight_data);
     gru_value.stateWeight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
-    gru_value.prevOutValue = const_cast<T*>(h0_data);
+    Tensor ordered_h0;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (h0) {
+      // Since the batch computation for GRU reorders the input sequences
+      // according to their length, the initial hidden state also needs
+      // to be reordered.
+      ReorderInitState<Place, T>(context.device_context(), *h0, order,
+                                 &ordered_h0, true);
+      gru_value.prevOutValue = ordered_h0.data<T>();
+    } else {
+      gru_value.prevOutValue = nullptr;
+    }
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
     for (size_t n = 0; n < num_batch; n++) {
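
The substance of this hunk: sequence2batch sorts the input sequences by decreasing length before batching, and level 2 of batch_gate's LoD records that permutation, so prevOutValue must point at a copy of H0 gathered into the same order rather than at the raw input. A minimal NumPy sketch of the gather that ReorderInitState performs when indexed_src is true, using the lod from the test below (names here are illustrative, not the operator's API):

import numpy as np

lod = [0, 2, 6, 9]                           # 3 sequences: lengths 2, 4, 3
lengths = np.diff(lod)                       # -> array([2, 4, 3])
order = np.argsort(-lengths, kind="stable")  # -> array([1, 2, 0]), longest first

h0 = np.random.rand(3, 5)  # one initial hidden state row per sequence
ordered_h0 = h0[order]     # gather: ordered_h0[i] = h0[order[i]]
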
@@ -110,7 +114,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     auto* h0 = context.Input<Tensor>("H0");
-    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* batch_gate = context.Input<LoDTensor>("BatchGate");
@@ -143,6 +146,16 @@ class GRUGradKernel : public framework::OpKernel<T> {
     zero(context.device_context(), &batch_reset_hidden_prev_grad,
          static_cast<T>(0.0));

+    Tensor ordered_h0, ordered_h0_grad;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (h0) {
+      ReorderInitState<Place, T>(context.device_context(), *h0, order,
+                                 &ordered_h0, true);
+    }
+    if (h0_grad) {
+      ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
+    }
+
     bool is_reverse = context.Attr<bool>("is_reverse");
     batch_hidden_grad.set_lod(batch_hidden->lod());
     to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
@@ -185,11 +198,13 @@ class GRUGradKernel : public framework::OpKernel<T> {
           batch_reset_hidden_prev_grad.Slice(bstart, bend);
       gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
       if (n == 0) {
-        gru_value.prevOutValue = const_cast<T*>(h0_data);
-        if (h0_grad) {
-          T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
-          zero(context.device_context(), h0_grad, static_cast<T>(0.0));
-          gru_grad.prevOutGrad = h0_grad_data;
+        if (h0) {
+          gru_value.prevOutValue = ordered_h0.data<T>();
+        } else {
+          gru_value.prevOutValue = nullptr;
+        }
+        if (h0 && h0_grad) {
+          gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
         } else {
           gru_grad.prevOutGrad = nullptr;
         }
@@ -220,6 +235,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
       auto place = context.GetEigenDevice<Place>();
       d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
     }
+    if (h0 && h0_grad) {
+      ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
+                                 order, h0_grad, false);
+    }
   }

   void Compute(const framework::ExecutionContext& context) const override {
...
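
The backward pass mirrors the forward gather: the gradient with respect to the initial state is accumulated into ordered_h0_grad in batch order, and the final hunk scatters it back to the original sequence order (indexed_src false writes dst[order[i]] = src[i]). A NumPy sketch of the scatter and a round-trip check (names illustrative):

import numpy as np

order = np.array([1, 2, 0])  # permutation stored in batch_gate's LoD level 2

h0 = np.random.rand(3, 5)
ordered = h0[order]                # forward gather (indexed_src = true)
restored = np.empty_like(h0)
restored[order] = ordered          # backward scatter (indexed_src = false)
assert np.allclose(restored, h0)   # round trip restores the original order
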
@@ -6,7 +6,8 @@ from test_lstm_op import identity, sigmoid, tanh, relu
 class TestGRUOp(OpTest):
-    batch_size = 9
+    lod = [[0, 2, 6, 9]]
+    batch_size = lod[0][-1]
     frame_size = 5

     activate = {
         'identity': identity,
@@ -35,7 +36,7 @@ class TestGRUOp(OpTest):
                         seq_starts[sorted_seqs[i]] + batch_idx)
                 idx_in_seq.append(idx)
             idx_in_seq_list.append(idx_in_seq)
-        return idx_in_seq_list
+        return idx_in_seq_list, sorted_seqs

     def gru_step(self, x, h_p, w, b):
         batch_size = x.shape[0]
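
Returning sorted_seqs lets the Python reference apply the same permutation to H0 (self.inputs['H0'][self.sorted_seqs] in the next hunk) that the C++ kernel reads from batch_gate->lod()[2], so both sides consume the initial state in identical order. For this test's lod = [[0, 2, 6, 9]], a quick sketch of the expected permutation, mirroring the test's sort by decreasing length:

seq_starts = [0, 2, 6, 9]
lengths = [seq_starts[i + 1] - seq_starts[i] for i in range(len(seq_starts) - 1)]
sorted_seqs = sorted(range(len(lengths)), key=lambda i: -lengths[i])
print(lengths, sorted_seqs)  # [2, 4, 3] [1, 2, 0]
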
@@ -66,8 +67,8 @@ class TestGRUOp(OpTest):
         batch_hidden = self.outputs['BatchHidden']
         hidden = self.outputs['Hidden']
         idx_in_seq_list = self.idx_in_seq_list
-        h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
-            (len(idx_in_seq_list[0]), self.frame_size))
+        h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key(
+            'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size))
         num_batch = len(idx_in_seq_list)
         end_idx = 0
         for batch_idx in range(num_batch):
@@ -84,8 +85,9 @@ class TestGRUOp(OpTest):
         return batch_gate, batch_reset_hidden_prev, hidden

     def set_data(self):
-        lod = [[0, 2, 6, self.batch_size]]
-        self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
+        lod = self.lod
+        self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
+            lod, self.is_reverse)
         batch_size = self.batch_size
         frame_size = self.frame_size
         input = np.random.rand(batch_size, frame_size * 3).astype('float64')
@@ -146,8 +148,8 @@ class TestGRUOpReverse(TestGRUOp):
     def set_confs(self):
         self.is_reverse = True
         self.attrs = {
-            'activation': 'identity',
-            'gate_activation': 'sigmoid',
+            'activation': 'tanh',
+            'gate_activation': 'tanh',
             'is_reverse': self.is_reverse
         }
...