From d60fe75ac36d1a34f049acd65b17cbe2d76a2972 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 9 Nov 2017 16:23:48 +0800 Subject: [PATCH] follow comments. --- paddle/operators/lstm_op.cc | 30 +++--- paddle/operators/lstm_op.h | 94 ++++++++++--------- .../paddle/v2/framework/tests/test_lstm_op.py | 78 +++++---------- 3 files changed, 83 insertions(+), 119 deletions(-) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index d99e008447..4cbb60f3fd 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -246,25 +246,17 @@ class LSTMGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), "Input(BatchGate) of LSTM should not be null."); - auto in_g_name = framework::GradVarName("Input"); - if (ctx->HasOutput(in_g_name)) - ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input")); - - auto w_g_name = framework::GradVarName("Weight"); - if (ctx->HasOutput(w_g_name)) - ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight")); - - auto b_g_name = framework::GradVarName("Bias"); - if (ctx->HasOutput(b_g_name)) - ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias")); - - auto h0_g_name = framework::GradVarName("H0"); - if (ctx->HasOutput(h0_g_name)) - ctx->SetOutputDim(h0_g_name, ctx->GetInputDim("H0")); - - auto c0_g_name = framework::GradVarName("C0"); - if (ctx->HasOutput(c0_g_name)) - ctx->SetOutputDim(c0_g_name, ctx->GetInputDim("C0")); + auto SetOutGradDim = [&ctx](const std::string& name) { + auto g_name = framework::GradVarName(name); + if (ctx->HasOutput(g_name)) + ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); + }; + + SetOutGradDim("Input"); + SetOutGradDim("Weight"); + SetOutGradDim("Bias"); + SetOutGradDim("H0"); + SetOutGradDim("C0"); } protected: diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 26856f4a6e..fca84e2d8f 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -28,6 +28,15 @@ template using EigenMatrix = 
framework::EigenMatrix; +template +inline void ReorderInitState(const platform::DeviceContext& ctx, + const framework::Tensor& src, const size_t* index, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims(), ctx.GetPlace()); + row_shuffle(ctx, src, index, *dst, indexed_src); +} + template class LSTMKernel : public framework::OpKernel { public: @@ -83,11 +92,13 @@ class LSTMKernel : public framework::OpKernel { } lstm_value.prevStateValue = nullptr; Tensor ordered_c0; + const size_t* order = batch_gate->lod()[2].data(); if (cell_t0) { - math::CopyMatrixRowsFunctor row_shuffle; - ordered_c0.mutable_data(cell_t0->dims(), ctx.GetPlace()); - const size_t* order = batch_gate->lod()[2].data(); - row_shuffle(device_ctx, *cell_t0, order, ordered_c0, true); + // Since the batch computing for LSTM reorders the input sequences + // according to their length, the initial cell state also needs + // to be reordered. + ReorderInitState(device_ctx, *cell_t0, order, &ordered_c0, + true); lstm_value.prevStateValue = ordered_c0.data(); } @@ -123,11 +134,16 @@ class LSTMKernel : public framework::OpKernel { static_cast(1.0), &gate_t, static_cast(1.0)); } else if (hidden_t0) { - math::CopyMatrixRowsFunctor row_shuffle; + // If n == 0 and there is no initial hidden state, i.e. H0 is all + // zeros, the calculation W_h * H0 is skipped. + // If n == 0 and there is an initial hidden state, calculate W_h * H0. + + // Since the batch computing for LSTM reorders the input sequences + // according to their length, the initial hidden state also needs + // to be reordered. 
Tensor ordered_h0; - ordered_h0.mutable_data(hidden_t0->dims(), ctx.GetPlace()); - const size_t* order = batch_gate->lod()[2].data(); - row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true); + ReorderInitState(device_ctx, *hidden_t0, order, &ordered_h0, + true); math::matmul(device_ctx, ordered_h0, false, *weight, false, static_cast(1.0), &gate_t, static_cast(1.0)); @@ -187,12 +203,16 @@ class LSTMGradKernel : public framework::OpKernel { zero(device_ctx, weight_g, static_cast(0.0)); } + // ordered_h0/c0 is the reordered hidden/cell initialization. + // ordered_h0_g/c0_g is the reordered gradient of hidden/cell + // initialization. Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g; - math::CopyMatrixRowsFunctor row_shuffle; const size_t* order = batch_gate->lod()[2].data(); if (c0) { - ordered_c0.mutable_data(c0->dims(), ctx.GetPlace()); - row_shuffle(device_ctx, *c0, order, ordered_c0, true); + ReorderInitState(device_ctx, *c0, order, &ordered_c0, true); + } + if (c0 && c0_g) { + ordered_c0_g.mutable_data(c0_g->dims(), ctx.GetPlace()); } auto in_dims = input->dims(); @@ -231,30 +251,24 @@ class LSTMGradKernel : public framework::OpKernel { math::LoDTensor2BatchFunctor to_batch; - // use the local variable as here. 
- LoDTensor batch_hidden; - batch_hidden.mutable_data(out_dims, ctx.GetPlace()); - batch_hidden.set_lod(batch_gate->lod()); - to_batch(device_ctx, *hidden_out, batch_hidden, false); - - LoDTensor batch_hidden_g; - batch_hidden_g.mutable_data(out_dims, ctx.GetPlace()); - batch_hidden_g.set_lod(batch_gate->lod()); - to_batch(device_ctx, *hidden_g, batch_hidden_g, false); + auto ToBatch = [&batch_gate, &to_batch]( + const platform::DeviceContext& ctx, const framework::LoDTensor& src, + const framework::DDim& dims, framework::LoDTensor& dst) { + dst.mutable_data(dims, ctx.GetPlace()); + dst.set_lod(batch_gate->lod()); + to_batch(ctx, src, dst, false); + }; - LoDTensor batch_cell; - batch_cell.mutable_data(out_dims, ctx.GetPlace()); - batch_cell.set_lod(batch_gate->lod()); - to_batch(device_ctx, *cell_out, batch_cell, false); + LoDTensor batch_hidden, batch_hidden_g, batch_cell; + ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden); + ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g); + ToBatch(device_ctx, *cell_out, out_dims, batch_cell); - LoDTensor batch_cell_g; + LoDTensor batch_cell_g, batch_gate_g; batch_cell_g.mutable_data(out_dims, ctx.GetPlace()); - batch_cell_g.set_lod(batch_gate->lod()); // TODO(qingqing) support the case output cell has gradient. 
// to_batch(device_ctx, *cell_g, batch_cell_g, false); zero(device_ctx, &batch_cell_g, static_cast(0.0)); - - LoDTensor batch_gate_g; batch_gate_g.mutable_data(batch_gate->dims(), ctx.GetPlace()); batch_gate_g.set_lod(batch_gate->lod()); @@ -289,17 +303,8 @@ class LSTMGradKernel : public framework::OpKernel { lstm_value.prevStateValue = cell_pre.data(); lstm_grad.prevStateGrad = cell_pre_g.data(); } else { - if (c0) { - lstm_value.prevStateValue = ordered_c0.data(); - } else { - lstm_value.prevStateValue = nullptr; - } - if (c0 && c0_g) { - ordered_c0_g.mutable_data(c0_g->dims(), ctx.GetPlace()); - lstm_grad.prevStateGrad = ordered_c0_g.data(); - } else { - lstm_grad.prevStateGrad = nullptr; - } + lstm_value.prevStateValue = c0 ? ordered_c0.data() : nullptr; + lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data() : nullptr; } int cur_batch_size = bend - bstart; @@ -323,8 +328,7 @@ class LSTMGradKernel : public framework::OpKernel { } } else { if (h0 && weight_g) { - ordered_h0.mutable_data(h0->dims(), ctx.GetPlace()); - row_shuffle(device_ctx, *h0, order, ordered_h0, true); + ReorderInitState(device_ctx, *h0, order, &ordered_h0, true); math::matmul(device_ctx, ordered_h0, true, gate_g, false, static_cast(1.0), weight_g, static_cast(1.0)); @@ -359,12 +363,10 @@ class LSTMGradKernel : public framework::OpKernel { } if (h0 && h0_g) { - h0_g->mutable_data(ctx.GetPlace()); - row_shuffle(device_ctx, ordered_h0_g, order, *h0_g, false); + ReorderInitState(device_ctx, ordered_h0_g, order, h0_g, false); } if (c0 && c0_g) { - c0_g->mutable_data(ctx.GetPlace()); - row_shuffle(device_ctx, ordered_c0_g, order, *c0_g, false); + ReorderInitState(device_ctx, ordered_c0_g, order, c0_g, false); } } }; diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index a4bb99cd7d..77f062e8c8 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -179,36 +179,6 @@ class 
TestLstmOp(OpTest): self.check_grad( ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) - def test_check_grad_ingore_bias(self): - N = len(self.lod[0]) - 1 - self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') - self.outputs['BatchCellPreAct'] = np.zeros( - (N, self.D)).astype('float64') - self.check_grad( - ['Input', 'Weight'], ['Hidden'], - max_relative_error=5e-4, - no_grad_set=set('Bias')) - - def test_check_grad_ingore_weight(self): - N = len(self.lod[0]) - 1 - self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') - self.outputs['BatchCellPreAct'] = np.zeros( - (N, self.D)).astype('float64') - self.check_grad( - ['Input', 'Bias'], ['Hidden'], - max_relative_error=5e-4, - no_grad_set=set('Weight')) - - def test_check_grad_ingore_input(self): - N = len(self.lod[0]) - 1 - self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') - self.outputs['BatchCellPreAct'] = np.zeros( - (N, self.D)).astype('float64') - self.check_grad( - ['Weight', 'Bias'], ['Hidden'], - max_relative_error=5e-4, - no_grad_set=set('Input')) - class TestLstmOpHasInitial(TestLstmOp): def set_argument(self): @@ -233,15 +203,35 @@ class TestLstmOpHasInitial(TestLstmOp): ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'], max_relative_error=5e-4) - # In order to speed up, skip following testing def test_check_grad_ingore_bias(self): - return + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Weight'], ['Hidden'], + max_relative_error=5e-4, + no_grad_set=set('Bias')) def test_check_grad_ingore_weight(self): - return + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Input', 'Bias'], ['Hidden'], + max_relative_error=5e-4, + 
no_grad_set=set('Weight')) def test_check_grad_ingore_input(self): - return + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') + self.check_grad( + ['Weight', 'Bias'], ['Hidden'], + max_relative_error=5e-4, + no_grad_set=set('Input')) def test_check_grad_ingore_h0(self): N = len(self.lod[0]) - 1 @@ -277,16 +267,6 @@ class TestLstmOpRerverse(TestLstmOp): self.is_reverse = True self.use_peepholes = True - # In order to speed up, skip following testing - def test_check_grad_ingore_bias(self): - return - - def test_check_grad_ingore_weight(self): - return - - def test_check_grad_ingore_input(self): - return - class TestLstmOpNotUsePeepholes(TestLstmOp): def set_argument(self): @@ -301,16 +281,6 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.is_reverse = True self.use_peepholes = False - # In order to speed up, skip following testing - def test_check_grad_ingore_bias(self): - return - - def test_check_grad_ingore_weight(self): - return - - def test_check_grad_ingore_input(self): - return - if __name__ == '__main__': unittest.main() -- GitLab