未验证 提交 7a57b3b7 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #5623 from guoshengCS/fix-H0-GRUOp

Fix data order of H0 in GRU Operator
...@@ -24,8 +24,17 @@ ...@@ -24,8 +24,17 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
// Shuffles the rows of an initial-state tensor through
// math::CopyMatrixRowsFunctor.
//
// ctx          device context; supplies the place on which dst is allocated
//              and is forwarded to the functor.
// src          tensor whose rows are read.
// index        row-index array forwarded unchanged to the functor.
// dst          output tensor; (re)allocated here with the same dims as src.
// indexed_src  forwarded unchanged to the functor. NOTE(review): presumably
//              selects whether `index` addresses rows of src (true) or rows
//              of dst (false) -- confirm against CopyMatrixRowsFunctor.
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
                             const framework::Tensor& src, const size_t* index,
                             framework::Tensor* dst, bool indexed_src) {
  // Make sure the destination has storage of src's shape before the
  // functor writes into it.
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());

  math::CopyMatrixRowsFunctor<Place, T> copy_rows;
  copy_rows(ctx, src, index, *dst, indexed_src);
}
template <typename Place, typename T> template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> { class GRUKernel : public framework::OpKernel<T> {
...@@ -33,7 +42,6 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -33,7 +42,6 @@ class GRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<LoDTensor>("Input"); auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias"); auto* bias = context.Input<Tensor>("Bias");
...@@ -66,7 +74,18 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -66,7 +74,18 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.gateWeight = const_cast<T*>(weight_data); gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight = gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size); const_cast<T*>(weight_data + 2 * frame_size * frame_size);
gru_value.prevOutValue = const_cast<T*>(h0_data); Tensor ordered_h0;
const size_t* order = batch_gate->lod()[2].data();
if (h0) {
// The batch computation for GRU reorders the input sequences according
// to their length, so the initial state (H0) must be reordered in the
// same way before it is used.
ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true);
gru_value.prevOutValue = ordered_h0.data<T>();
} else {
gru_value.prevOutValue = nullptr;
}
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1; size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; n++) { for (size_t n = 0; n < num_batch; n++) {
...@@ -102,7 +121,6 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -102,7 +121,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
public: public:
void BatchCompute(const framework::ExecutionContext& context) const { void BatchCompute(const framework::ExecutionContext& context) const {
auto* h0 = context.Input<Tensor>("H0"); auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight"); auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
auto* batch_gate = context.Input<LoDTensor>("BatchGate"); auto* batch_gate = context.Input<LoDTensor>("BatchGate");
...@@ -135,6 +153,17 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -135,6 +153,17 @@ class GRUGradKernel : public framework::OpKernel<T> {
zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0)); zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0)); zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
Tensor ordered_h0, ordered_h0_grad;
const size_t* order = batch_gate->lod()[2].data();
if (h0) {
ReorderInitState<Place, T>(context.device_context(), *h0, order,
&ordered_h0, true);
}
if (h0_grad) {
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
zero(context.device_context(), &ordered_h0_grad, static_cast<T>(0.0));
}
bool is_reverse = context.Attr<bool>("is_reverse"); bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod()); batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
...@@ -176,14 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -176,14 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_reset_hidden_prev_grad.Slice(bstart, bend); batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>(); gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) { if (n == 0) {
gru_value.prevOutValue = const_cast<T*>(h0_data); gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
if (h0_grad) { gru_grad.prevOutGrad =
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace()); h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
zero(dev_ctx, h0_grad, static_cast<T>(0.0));
gru_grad.prevOutGrad = h0_grad_data;
} else {
gru_grad.prevOutGrad = nullptr;
}
} else { } else {
int bstart_pre = static_cast<int>(batch_starts[n - 1]); int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
...@@ -208,6 +232,10 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -208,6 +232,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::ColwiseSum<Place, T> col_sum; math::ColwiseSum<Place, T> col_sum;
col_sum(dev_ctx, batch_gate_grad, bias_grad); col_sum(dev_ctx, batch_gate_grad, bias_grad);
} }
if (h0 && h0_grad) {
ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
order, h0_grad, false);
}
} }
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
......
...@@ -6,7 +6,8 @@ from test_lstm_op import identity, sigmoid, tanh, relu ...@@ -6,7 +6,8 @@ from test_lstm_op import identity, sigmoid, tanh, relu
class TestGRUOp(OpTest): class TestGRUOp(OpTest):
batch_size = 9 lod = [[0, 2, 6, 9]]
batch_size = lod[0][-1]
frame_size = 5 frame_size = 5
activate = { activate = {
'identity': identity, 'identity': identity,
...@@ -35,7 +36,7 @@ class TestGRUOp(OpTest): ...@@ -35,7 +36,7 @@ class TestGRUOp(OpTest):
seq_starts[sorted_seqs[i]] + batch_idx) seq_starts[sorted_seqs[i]] + batch_idx)
idx_in_seq.append(idx) idx_in_seq.append(idx)
idx_in_seq_list.append(idx_in_seq) idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list return idx_in_seq_list, sorted_seqs
def gru_step(self, x, h_p, w, b): def gru_step(self, x, h_p, w, b):
batch_size = x.shape[0] batch_size = x.shape[0]
...@@ -66,8 +67,8 @@ class TestGRUOp(OpTest): ...@@ -66,8 +67,8 @@ class TestGRUOp(OpTest):
batch_hidden = self.outputs['BatchHidden'] batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden'] hidden = self.outputs['Hidden']
idx_in_seq_list = self.idx_in_seq_list idx_in_seq_list = self.idx_in_seq_list
h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros( h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key(
(len(idx_in_seq_list[0]), self.frame_size)) 'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size))
num_batch = len(idx_in_seq_list) num_batch = len(idx_in_seq_list)
end_idx = 0 end_idx = 0
for batch_idx in range(num_batch): for batch_idx in range(num_batch):
...@@ -84,8 +85,9 @@ class TestGRUOp(OpTest): ...@@ -84,8 +85,9 @@ class TestGRUOp(OpTest):
return batch_gate, batch_reset_hidden_prev, hidden return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self): def set_data(self):
lod = [[0, 2, 6, self.batch_size]] lod = self.lod
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse) self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
lod, self.is_reverse)
batch_size = self.batch_size batch_size = self.batch_size
frame_size = self.frame_size frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64') input = np.random.rand(batch_size, frame_size * 3).astype('float64')
...@@ -146,7 +148,7 @@ class TestGRUOpReverse(TestGRUOp): ...@@ -146,7 +148,7 @@ class TestGRUOpReverse(TestGRUOp):
def set_confs(self): def set_confs(self):
self.is_reverse = True self.is_reverse = True
self.attrs = { self.attrs = {
'activation': 'identity', 'activation': 'tanh',
'gate_activation': 'sigmoid', 'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse 'is_reverse': self.is_reverse
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册